diff --git a/.github/.keep b/.github/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/.github/workflows/auto-fix.yml b/.github/workflows/auto-fix.yml
new file mode 100644
index 0000000..dcbf186
--- /dev/null
+++ b/.github/workflows/auto-fix.yml
@@ -0,0 +1,68 @@
+name: 👾 auto-fix code quality
+
+on:
+ pull_request:
+ branches: [ main, master ]
+ push:
+ branches: [ main, master ]
+ workflow_dispatch: # allow manual triggering
+
+jobs:
+ auto-fix:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: checkout code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # full history for better diff detection
+
+ - name: set up python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.9'
+ cache: 'pip'
+
+ - name: install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r scripts/requirements.txt
+
+ - name: install code quality tools
+ run: |
+ pip install black isort autopep8 autoflake
+
+ - name: run auto-fix script
+ run: |
+ python scripts/auto_fix.py
+
+ - name: check for changes
+ id: check_changes
+ run: |
+ if [ -n "$(git status --porcelain)" ]; then
+ echo "changes=true" >> $GITHUB_OUTPUT
+ echo "files were modified by auto-fix script"
+ git status --porcelain
+ else
+ echo "changes=false" >> $GITHUB_OUTPUT
+ echo "no files were modified"
+ fi
+
+ - name: commit and push changes (if any)
+ if: steps.check_changes.outputs.changes == 'true'
+ run: |
+ git config --local user.email "action@github.com"
+ git config --local user.name "github action"
+ git add -a
+ git commit -m "🔧 auto-fix code quality issues
+
+ - applied black formatting
+ - organized imports with isort
+ - fixed code style with autopep8
+ - removed unused imports with autoflake
+ - fixed markdown formatting
+ - validated and fixed links
+ - removed trailing whitespace
+
+ auto-generated by github actions"
+ git push
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7c42ec4
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,204 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Django stuff (keeping in case of web components)
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff (keeping in case of web components)
+instance/
+.webassets-cache
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+Pipfile.lock
+
+# poetry
+poetry.lock
+
+# pdm
+pdm.lock
+.pdm.toml
+
+# PEP 582
+__pypackages__/
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+.idea/
+*.iml
+*.ipr
+*.iws
+
+# VS Code
+.vscode/
+
+# macOS
+.DS_Store
+.AppleDouble
+.LSOverride
+Icon
+._*
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+# Windows
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+*.stackdump
+[Dd]esktop.ini
+$RECYCLE.BIN/
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+*.lnk
+
+# Linux
+*~
+.fuse_hidden*
+.directory
+.Trash-*
+.nfs*
+
+# Machine Learning specific
+*.pkl
+*.pickle
+*.joblib
+*.h5
+*.hdf5
+*.model
+*.weights
+*.ckpt
+*.pth
+*.pt
+*.onnx
+*.tflite
+*.pb
+
+# Data files
+*.csv
+*.json
+*.parquet
+*.feather
+*.hdf
+*.xlsx
+*.xls
+
+# Large model files
+models/
+checkpoints/
+runs/
+logs/
+wandb/
+
+# Environment variables
+.env.local
+.env.development
+.env.test
+.env.production
+
+# IDE specific
+*.swp
+*.swo
diff --git a/EBMs/README.md b/EBMs/README.md
index 1469a1a..f161419 100644
--- a/EBMs/README.md
+++ b/EBMs/README.md
@@ -1,12 +1,15 @@
-## quantum ai: training energy-based-models using openai
+## quantum ai: training energy-based-models using openAI
-
-#### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
+
+#### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in
+energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
+---
+
### installing
@@ -19,7 +22,8 @@ brew install pkg-config
-* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this problem (`PMIX ERROR: ERROR`) that can be fixed with:
+* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this
+problem (`PMIX ERROR: ERROR`) that can be fixed with:
@@ -40,7 +44,8 @@ pip install -r requirements.txt
```
-* note that this is an adapted requirement file since the **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
+* note that this is an adapted requirement file since the **[openai's
+original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
* finally, download and install **[mujoco](https://www.roboti.us/index.html)**
* you will also need to register for a license, which asks for a machine ID
* the documentation on the website is incomplete, so just download the suggested script and run:
@@ -64,7 +69,8 @@ mv getid_osx getid_osx.dms
-* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`:
+* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder
+`cachedir`:
@@ -78,7 +84,8 @@ mkdir cachedir
-* openai's original code contains **[hardcoded constants that only work on Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
+* openai's original code contains **[hardcoded constants that only work on
+Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
* i changed this to a constant (`ROOT_DIR = "./results"`) in the top of `data.py`
@@ -87,7 +94,8 @@ mkdir cachedir
-* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased substantially by using multiple different workers by running each command:
+* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased
+substantially by using multiple different workers by running each command:
@@ -102,7 +110,8 @@ mpiexec -n
```
-python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
+python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
+--zero_kl --replay_batch --large_model
```
* this should generate the following output:
@@ -112,7 +121,8 @@ python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_si
```bash
Instructions for updating:
Use tf.gfile.GFile.
-2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
+2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is
+deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
64 batch size
Local rank: 0 1
Loading data...
@@ -121,11 +131,15 @@ Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Done loading...
-WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
+WARNING:tensorflow:From
+/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263:
+colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Building graph...
-WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
+WARNING:tensorflow:From
+/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from
+tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Finished processing loop construction ...
@@ -136,16 +150,36 @@ Model has a total of 7567880 parameters
Initializing variables...
Start broadcast
End broadcast
-Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996, l
-oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818, context_0/res_optim_res_c1/
-cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10, context_0/res_optim_res_c2/gb:0: 6.27235652306268
-3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0
-: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2
-_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.0194275379180908
-2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.75612463
-1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0: 0.34806951880455017, context_0/res_4_res_c2/g:
-0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5
-_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/
+Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff:
+0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent:
+2.272536277770996, l
+oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first:
+0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818,
+context_0/res_optim_res_c1/
+cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10,
+context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10,
+context_0/res_optim_res_c2/gb:0: 6.27235652306268
+3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11,
+context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437,
+context_0/res_1_res_c2/g:0
+: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0:
+1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0:
+6.868505764145993e-10, context_0/res_2
+_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0:
+8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0:
+0.0194275379180908
+2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10,
+context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11,
+context_0/res_3_res_c2/gb:0: 7.75612463
+1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0:
+4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0:
+0.34806951880455017, context_0/res_4_res_c2/g:
+0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0:
+5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0:
+3.807749116013781e-10, context_0/res_5
+_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0:
+7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039,
+context_0/fc5/
b:0: 0.4506262540817261,
................................................................................................................................
@@ -159,7 +193,8 @@ Inception score of 1.2397289276123047 with std of 0.0
```
-python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
+python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
+--zero_kl --replay_batch --cclass
```
@@ -169,7 +204,8 @@ python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size
```
-python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=
+python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01
+--replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=
```
@@ -179,7 +215,8 @@ python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=3
```
-python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=
+python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass
+--zero_kl --dataset=imagenetfull --imagenet_datadir=
```
@@ -197,7 +234,8 @@ python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
```
-* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by different settings of task flag in the file
+* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by
+different settings of task flag in the file
* for example, to visualize cross class mappings in cifar-10, you can run:
@@ -217,7 +255,8 @@ python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resu
```
-python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
+python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200
+--large_model --svhnmix --cclass=False
```
@@ -227,7 +266,8 @@ python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_
```
-python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd= --num_steps=10 --lival= --wider_model
+python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd= --num_steps=10 --lival= --wider_model
```
@@ -236,12 +276,14 @@ python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=
-* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in `cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
+* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in
+`cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
```
-python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch -cclass
+python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act
+--cond_pos --replay_batch -cclass
```
@@ -249,5 +291,7 @@ python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps
* once models are trained, they can be sampled from jointly by running:
```
-python ebm_combine.py --task=conceptcombine --exp_size= --exp_shape= --exp_pos= --exp_rot= --resume_size= --resume_shape= --resume_rot= --resume_pos=
+python ebm_combine.py --task=conceptcombine --exp_size= --exp_shape= --exp_pos=
+--exp_rot= --resume_size= --resume_shape= --resume_rot=
+--resume_pos=
```
diff --git a/EBMs/ais.py b/EBMs/ais.py
index c22cc6f..1e338d2 100644
--- a/EBMs/ais.py
+++ b/EBMs/ais.py
@@ -1,40 +1,65 @@
-import tensorflow as tf
import math
+import os.path as osp
+
+import numpy as np
+import tensorflow as tf
+from data import Cifar10, DSprites, Mnist
from hmc import hmc
+from models import DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Wider
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
-from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
-from data import Cifar10, Mnist, DSprites
-from scipy.misc import logsumexp
-from scipy.misc import imsave
-from utils import optimistic_restore
-import os.path as osp
-import numpy as np
from tqdm import tqdm
+from utils import optimistic_restore
-flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
-flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
-flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
-flags.DEFINE_string('exp', 'default', 'name of experiments')
-flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
-flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
-flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
+flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
+flags.DEFINE_string(
+ "dataset", "cifar10", "cifar10 or mnist or dsprites or 2d or toy Gauss"
+)
+flags.DEFINE_string(
+ "logdir", "cachedir", "location where log of experiments will be stored"
+)
+flags.DEFINE_string("exp", "default", "name of experiments")
+flags.DEFINE_integer(
+ "data_workers", 5, "Number of different data workers to load data in parallel"
+)
+flags.DEFINE_integer("batch_size", 16, "Size of inputs")
+flags.DEFINE_string("resume_iter", "-1", "iteration to resume training from")
-flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
-flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
-flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
-flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
-flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
-flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
-flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
-flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
-flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
-flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
-flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
-flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
-flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
-flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
-flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
+flags.DEFINE_bool(
+ "max_pool",
+ False,
+ "Whether or not to use max pooling rather than strided convolutions",
+)
+flags.DEFINE_integer(
+ "num_filters",
+ 64,
+ "number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
+)
+flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
+flags.DEFINE_integer("gauss_dim", 500, "dimensions for modeling Gaussian")
+flags.DEFINE_integer(
+ "rescale", 1, "factor to rescale input outside of normal (0, 1) box"
+)
+flags.DEFINE_float(
+ "temperature", 1, "temperature at which to compute likelihood of model"
+)
+flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
+flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
+flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
+flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
+flags.DEFINE_bool(
+ "cclass",
+ False,
+ "Whether to evaluate the log likelihood of conditional model or not",
+)
+flags.DEFINE_bool(
+ "single",
+ False,
+ "Whether to evaluate the log likelihood of conditional model or not",
+)
+flags.DEFINE_bool("large_model", False, "Use large model to evaluate")
+flags.DEFINE_bool("wider_model", False, "Use large model to evaluate")
+flags.DEFINE_float("alr", 0.0045, "Learning rate to use for HMC steps")
FLAGS = flags.FLAGS
@@ -45,11 +70,12 @@ label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
def unscale_im(im):
return (255 * np.clip(im, 0, 1)).astype(np.uint8)
+
def gauss_prob_log(x, prec=1.0):
nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
- prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
+ prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2.0 * prec
return norm_constant_log + prob_density_log
@@ -73,23 +99,36 @@ def model_prob_log(x, e_func, weights, temp):
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
if FLAGS.dataset == "gauss":
- norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
+ norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(
+ x, prec=FLAGS.temperature
+ )
else:
- norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
- # Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
+ norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(
+ x, e_func, weights, temp
+ )
+ # Add an additional log likelihood penalty so that points outside of (0,
+ # 1) box are *highly* unlikely
-
- if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
- oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
- elif FLAGS.dataset == 'mnist':
- oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
+ if FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
+ oob_prob = tf.reduce_sum(
+ tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1]
+ )
+ elif FLAGS.dataset == "mnist":
+ oob_prob = tf.reduce_sum(
+ tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2]
+ )
else:
- oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
+ oob_prob = tf.reduce_sum(
+ tf.square(100 * (x - tf.clip_by_value(x, 0.0, FLAGS.rescale))),
+ axis=[1, 2, 3],
+ )
return -norm_prob + oob_prob
-def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
+def ancestral_sample(
+ e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10
+):
if FLAGS.dataset == "2d":
x = tf.placeholder(tf.float32, shape=(None, 2))
elif FLAGS.dataset == "gauss":
@@ -130,41 +169,46 @@ def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_
def main():
# Initialize dataset
- if FLAGS.dataset == 'cifar10':
+ if FLAGS.dataset == "cifar10":
dataset = Cifar10(train=False, rescale=FLAGS.rescale)
channel_num = 3
- dim_input = 32 * 32 * 3
- elif FLAGS.dataset == 'imagenet':
+ 32 * 32 * 3
+ elif FLAGS.dataset == "imagenet":
dataset = ImagenetClass()
channel_num = 3
- dim_input = 64 * 64 * 3
- elif FLAGS.dataset == 'mnist':
+ 64 * 64 * 3
+ elif FLAGS.dataset == "mnist":
dataset = Mnist(train=False, rescale=FLAGS.rescale)
channel_num = 1
- dim_input = 28 * 28 * 1
- elif FLAGS.dataset == 'dsprites':
+ 28 * 28 * 1
+ elif FLAGS.dataset == "dsprites":
dataset = DSprites()
channel_num = 1
- dim_input = 64 * 64 * 1
- elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
+ 64 * 64 * 1
+ elif FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
dataset = Box2D()
- dim_output = 1
- data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
+ data_loader = DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.data_workers,
+ drop_last=False,
+ shuffle=True,
+ )
- if FLAGS.dataset == 'mnist':
+ if FLAGS.dataset == "mnist":
model = MnistNet(num_channels=channel_num)
- elif FLAGS.dataset == 'cifar10':
+ elif FLAGS.dataset == "cifar10":
if FLAGS.large_model:
model = ResNet32Large(num_filters=128)
elif FLAGS.wider_model:
model = ResNet32Wider(num_filters=192)
else:
model = ResNet32(num_channels=channel_num, num_filters=128)
- elif FLAGS.dataset == 'dsprites':
+ elif FLAGS.dataset == "dsprites":
model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
- weights = model.construct_weights('context_{}'.format(0))
+ weights = model.construct_weights("context_{}".format(0))
config = tf.ConfigProto()
sess = tf.Session(config=config)
@@ -173,8 +217,8 @@ def main():
sess.run(tf.global_variables_initializer())
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
- model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
- resume_itr = FLAGS.resume_iter
+ model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
+ FLAGS.resume_iter
if FLAGS.resume_iter != "-1":
optimistic_restore(sess, model_file)
@@ -182,14 +226,17 @@ def main():
print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
# saver.restore(sess, model_file)
- chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
+ chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
+ model, weights, FLAGS.batch_size, temp=FLAGS.temperature
+ )
print("Finished constructing ancestral sample ...................")
if FLAGS.dataset != "gauss":
- comb_weights_cum = []
batch_size = tf.shape(x_init)[0]
label_tiled = tf.tile(label_default, (batch_size, 1))
- e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
+ e_compute = -FLAGS.temperature * model.forward(
+ x_init, weights, label=label_tiled
+ )
e_pos_list = []
for data_corrupt, data, label_gt in tqdm(data_loader):
@@ -205,44 +252,75 @@ def main():
alr = 0.0085
elif FLAGS.dataset == "mnist":
alr = 0.0065
- #90 alr = 0.0035
+ # 90 alr = 0.0035
else:
# alr = 0.0125
if FLAGS.rescale == 8:
alr = 0.0085
else:
alr = 0.0045
-#
+ #
for i in range(1):
tot_weight = 0
- for j in tqdm(range(1, FLAGS.pdist+1)):
+ for j in tqdm(range(1, FLAGS.pdist + 1)):
if j == 1:
if FLAGS.dataset == "cifar10":
- x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
+ x_curr = np.random.uniform(
+ 0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)
+ )
elif FLAGS.dataset == "gauss":
- x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
+ x_curr = np.random.uniform(
+ 0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)
+ )
elif FLAGS.dataset == "mnist":
- x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
+ x_curr = np.random.uniform(
+ 0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)
+ )
else:
- x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
+ x_curr = np.random.uniform(
+ 0, FLAGS.rescale, size=(FLAGS.batch_size, 2)
+ )
- alpha_prev = (j-1) / FLAGS.pdist
+ alpha_prev = (j - 1) / FLAGS.pdist
alpha_new = j / FLAGS.pdist
- cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
+ cweight, x_curr = sess.run(
+ [chain_weights, x],
+ {
+ a_prev: alpha_prev,
+ a_new: alpha_new,
+ x_init: x_curr,
+ approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
+ },
+ )
tot_weight = tot_weight + cweight
- print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
+ print(
+ "Total values of lower value based off forward sampling",
+ np.mean(tot_weight),
+ np.std(tot_weight),
+ )
tot_weight = 0
for j in tqdm(range(FLAGS.pdist, 0, -1)):
- alpha_new = (j-1) / FLAGS.pdist
+ alpha_new = (j - 1) / FLAGS.pdist
alpha_prev = j / FLAGS.pdist
- cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
+ cweight, x_curr = sess.run(
+ [chain_weights, x],
+ {
+ a_prev: alpha_prev,
+ a_new: alpha_new,
+ x_init: x_curr,
+ approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
+ },
+ )
tot_weight = tot_weight - cweight
- print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
-
+ print(
+ "Total values of upper value based off backward sampling",
+ np.mean(tot_weight),
+ np.std(tot_weight),
+ )
if __name__ == "__main__":
diff --git a/EBMs/custom_adam.py b/EBMs/custom_adam.py
index 71789fe..4dc0452 100644
--- a/EBMs/custom_adam.py
+++ b/EBMs/custom_adam.py
@@ -14,223 +14,244 @@
# ==============================================================================
"""Adam for TensorFlow."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import, division, print_function
+import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.training import optimizer
-from tensorflow.python.training import training_ops
+from tensorflow.python.ops import (
+ control_flow_ops,
+ math_ops,
+ resource_variable_ops,
+ state_ops,
+)
+from tensorflow.python.training import optimizer, training_ops
from tensorflow.python.util.tf_export import tf_export
-import tensorflow as tf
@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
- """Optimizer that implements the Adam algorithm.
+ """Optimizer that implements the Adam algorithm.
- See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
- ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
- """
-
- def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
- use_locking=False, name="Adam"):
- """Construct a new Adam optimizer.
-
- Initialization:
-
- $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
- $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
- $$t := 0 \text{(Initialize timestep)}$$
-
- The update rule for `variable` with gradient `g` uses an optimization
- described at the end of section2 of the paper:
-
- $$t := t + 1$$
- $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
-
- $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
- $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
- $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
-
- The default value of 1e-8 for epsilon might not be a good default in
- general. For example, when training an Inception network on ImageNet a
- current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
- formulation just before Section 2.1 of the Kingma and Ba paper rather than
- the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
- hat" in the paper.
-
- The sparse implementation of this algorithm (used when the gradient is an
- IndexedSlices object, typically because of `tf.gather` or an embedding
- lookup in the forward pass) does apply momentum to variable slices even if
- they were not used in the forward pass (meaning they have a gradient equal
- to zero). Momentum decay (beta1) is also applied to the entire momentum
- accumulator. This means that the sparse behavior is equivalent to the dense
- behavior (in contrast to some momentum implementations which ignore momentum
- unless a variable slice was actually used).
-
- Args:
- learning_rate: A Tensor or a floating point value. The learning rate.
- beta1: A float value or a constant float tensor.
- The exponential decay rate for the 1st moment estimates.
- beta2: A float value or a constant float tensor.
- The exponential decay rate for the 2nd moment estimates.
- epsilon: A small constant for numerical stability. This epsilon is
- "epsilon hat" in the Kingma and Ba paper (in the formula just before
- Section 2.1), not the epsilon in Algorithm 1 of the paper.
- use_locking: If True use locks for update operations.
- name: Optional name for the operations created when applying gradients.
- Defaults to "Adam".
-
- @compatibility(eager)
- When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
- `epsilon` can each be a callable that takes no arguments and returns the
- actual value to use. This can be useful for changing these values across
- different invocations of optimizer functions.
- @end_compatibility
+ See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+ ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
- super(AdamOptimizer, self).__init__(use_locking, name)
- self._lr = learning_rate
- self._beta1 = beta1
- self._beta2 = beta2
- self._epsilon = epsilon
- # Tensor versions of the constructor arguments, created in _prepare().
- self._lr_t = None
- self._beta1_t = None
- self._beta2_t = None
- self._epsilon_t = None
+ def __init__(
+ self,
+ learning_rate=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8,
+ use_locking=False,
+ name="Adam",
+ ):
+ """Construct a new Adam optimizer.
- # Created in SparseApply if needed.
- self._updated_lr = None
+ Initialization:
- def _get_beta_accumulators(self):
- with ops.init_scope():
- if context.executing_eagerly():
- graph = None
- else:
- graph = ops.get_default_graph()
- return (self._get_non_slot_variable("beta1_power", graph=graph),
- self._get_non_slot_variable("beta2_power", graph=graph))
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
- def _create_slots(self, var_list):
- # Create the beta1 and beta2 accumulators on the same device as the first
- # variable. Sort the var_list to make sure this device is consistent across
- # workers (these need to go on the same PS, otherwise some updates are
- # silently ignored).
- first_var = min(var_list, key=lambda x: x.name)
- self._create_non_slot_variable(initial_value=self._beta1,
- name="beta1_power",
- colocate_with=first_var)
- self._create_non_slot_variable(initial_value=self._beta2,
- name="beta2_power",
- colocate_with=first_var)
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section2 of the paper:
- # Create slots for the first and second moments.
- for v in var_list:
- self._zeros_slot(v, "m", self._name)
- self._zeros_slot(v, "v", self._name)
+ $$t := t + 1$$
+ $$lr_t := \text{learning\\_rate} * \\sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
- def _prepare(self):
- lr = self._call_if_callable(self._lr)
- beta1 = self._call_if_callable(self._beta1)
- beta2 = self._call_if_callable(self._beta2)
- epsilon = self._call_if_callable(self._epsilon)
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$variable := variable - lr_t * m_t / (\\sqrt{v_t} + \\epsilon)$$
- self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
- self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
- self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
- self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
+ formulation just before Section 2.1 of the Kingma and Ba paper rather than
+ the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+ hat" in the paper.
- def _apply_dense(self, grad, var):
- m = self.get_slot(var, "m")
- v = self.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators()
+ The sparse implementation of this algorithm (used when the gradient is an
+ IndexedSlices object, typically because of `tf.gather` or an embedding
+ lookup in the forward pass) does apply momentum to variable slices even if
+ they were not used in the forward pass (meaning they have a gradient equal
+ to zero). Momentum decay (beta1) is also applied to the entire momentum
+ accumulator. This means that the sparse behavior is equivalent to the dense
+ behavior (in contrast to some momentum implementations which ignore momentum
+ unless a variable slice was actually used).
- clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
- grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
- # Clip gradients by 3 std
- return training_ops.apply_adam(
- var, m, v,
- math_ops.cast(beta1_power, var.dtype.base_dtype),
- math_ops.cast(beta2_power, var.dtype.base_dtype),
- math_ops.cast(self._lr_t, var.dtype.base_dtype),
- math_ops.cast(self._beta1_t, var.dtype.base_dtype),
- math_ops.cast(self._beta2_t, var.dtype.base_dtype),
- math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
- grad, use_locking=self._use_locking).op
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ beta1: A float value or a constant float tensor.
+ The exponential decay rate for the 1st moment estimates.
+ beta2: A float value or a constant float tensor.
+ The exponential decay rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability. This epsilon is
+ "epsilon hat" in the Kingma and Ba paper (in the formula just before
+ Section 2.1), not the epsilon in Algorithm 1 of the paper.
+ use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
- def _resource_apply_dense(self, grad, var):
- m = self.get_slot(var, "m")
- v = self.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators()
- return training_ops.resource_apply_adam(
- var.handle, m.handle, v.handle,
- math_ops.cast(beta1_power, grad.dtype.base_dtype),
- math_ops.cast(beta2_power, grad.dtype.base_dtype),
- math_ops.cast(self._lr_t, grad.dtype.base_dtype),
- math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
- math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
- math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
- grad, use_locking=self._use_locking)
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
+ `epsilon` can each be a callable that takes no arguments and returns the
+ actual value to use. This can be useful for changing these values across
+ different invocations of optimizer functions.
+ @end_compatibility
+ """
+ super(AdamOptimizer, self).__init__(use_locking, name)
+ self._lr = learning_rate
+ self._beta1 = beta1
+ self._beta2 = beta2
+ self._epsilon = epsilon
- def _apply_sparse_shared(self, grad, var, indices, scatter_add):
- beta1_power, beta2_power = self._get_beta_accumulators()
- beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
- lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
- beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
- beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
- epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
- lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
- # m_t = beta1 * m + (1 - beta1) * g_t
- m = self.get_slot(var, "m")
- m_scaled_g_values = grad * (1 - beta1_t)
- m_t = state_ops.assign(m, m * beta1_t,
- use_locking=self._use_locking)
- with ops.control_dependencies([m_t]):
- m_t = scatter_add(m, indices, m_scaled_g_values)
- # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
- v = self.get_slot(var, "v")
- v_scaled_g_values = (grad * grad) * (1 - beta2_t)
- v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
- with ops.control_dependencies([v_t]):
- v_t = scatter_add(v, indices, v_scaled_g_values)
- v_sqrt = math_ops.sqrt(v_t)
- var_update = state_ops.assign_sub(var,
- lr * m_t / (v_sqrt + epsilon_t),
- use_locking=self._use_locking)
- return control_flow_ops.group(*[var_update, m_t, v_t])
+ # Tensor versions of the constructor arguments, created in _prepare().
+ self._lr_t = None
+ self._beta1_t = None
+ self._beta2_t = None
+ self._epsilon_t = None
- def _apply_sparse(self, grad, var):
- return self._apply_sparse_shared(
- grad.values, var, grad.indices,
- lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
- x, i, v, use_locking=self._use_locking))
+ # Created in SparseApply if needed.
+ self._updated_lr = None
- def _resource_scatter_add(self, x, i, v):
- with ops.control_dependencies(
- [resource_variable_ops.resource_scatter_add(
- x.handle, i, v)]):
- return x.value()
+ def _get_beta_accumulators(self):
+ with ops.init_scope():
+ if context.executing_eagerly():
+ graph = None
+ else:
+ graph = ops.get_default_graph()
+ return (
+ self._get_non_slot_variable("beta1_power", graph=graph),
+ self._get_non_slot_variable("beta2_power", graph=graph),
+ )
- def _resource_apply_sparse(self, grad, var, indices):
- return self._apply_sparse_shared(
- grad, var, indices, self._resource_scatter_add)
+ def _create_slots(self, var_list):
+ # Create the beta1 and beta2 accumulators on the same device as the first
+ # variable. Sort the var_list to make sure this device is consistent across
+ # workers (these need to go on the same PS, otherwise some updates are
+ # silently ignored).
+ first_var = min(var_list, key=lambda x: x.name)
+ self._create_non_slot_variable(
+ initial_value=self._beta1, name="beta1_power", colocate_with=first_var
+ )
+ self._create_non_slot_variable(
+ initial_value=self._beta2, name="beta2_power", colocate_with=first_var
+ )
- def _finish(self, update_ops, name_scope):
- # Update the power accumulators.
- with ops.control_dependencies(update_ops):
- beta1_power, beta2_power = self._get_beta_accumulators()
- with ops.colocate_with(beta1_power):
- update_beta1 = beta1_power.assign(
- beta1_power * self._beta1_t, use_locking=self._use_locking)
- update_beta2 = beta2_power.assign(
- beta2_power * self._beta2_t, use_locking=self._use_locking)
- return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
- name=name_scope)
+ # Create slots for the first and second moments.
+ for v in var_list:
+ self._zeros_slot(v, "m", self._name)
+ self._zeros_slot(v, "v", self._name)
+
+ def _prepare(self):
+ lr = self._call_if_callable(self._lr)
+ beta1 = self._call_if_callable(self._beta1)
+ beta2 = self._call_if_callable(self._beta2)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
+ self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
+ self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
+ self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
+
+ def _apply_dense(self, grad, var):
+ m = self.get_slot(var, "m")
+ v = self.get_slot(var, "v")
+ beta1_power, beta2_power = self._get_beta_accumulators()
+
+ clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
+ grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
+ # Clip gradients by 3 std
+ return training_ops.apply_adam(
+ var,
+ m,
+ v,
+ math_ops.cast(beta1_power, var.dtype.base_dtype),
+ math_ops.cast(beta2_power, var.dtype.base_dtype),
+ math_ops.cast(self._lr_t, var.dtype.base_dtype),
+ math_ops.cast(self._beta1_t, var.dtype.base_dtype),
+ math_ops.cast(self._beta2_t, var.dtype.base_dtype),
+ math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking,
+ ).op
+
+ def _resource_apply_dense(self, grad, var):
+ m = self.get_slot(var, "m")
+ v = self.get_slot(var, "v")
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ return training_ops.resource_apply_adam(
+ var.handle,
+ m.handle,
+ v.handle,
+ math_ops.cast(beta1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta2_power, grad.dtype.base_dtype),
+ math_ops.cast(self._lr_t, grad.dtype.base_dtype),
+ math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
+ math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
+ math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking,
+ )
+
+ def _apply_sparse_shared(self, grad, var, indices, scatter_add):
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+ lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+ beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+ beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+ epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+ lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)
+ # m_t = beta1 * m + (1 - beta1) * g_t
+ m = self.get_slot(var, "m")
+ m_scaled_g_values = grad * (1 - beta1_t)
+ m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
+ with ops.control_dependencies([m_t]):
+ m_t = scatter_add(m, indices, m_scaled_g_values)
+ # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+ v = self.get_slot(var, "v")
+ v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+ v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
+ with ops.control_dependencies([v_t]):
+ v_t = scatter_add(v, indices, v_scaled_g_values)
+ v_sqrt = math_ops.sqrt(v_t)
+ var_update = state_ops.assign_sub(
+ var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking
+ )
+ return control_flow_ops.group(*[var_update, m_t, v_t])
+
+ def _apply_sparse(self, grad, var):
+ return self._apply_sparse_shared(
+ grad.values,
+ var,
+ grad.indices,
+ lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
+ x, i, v, use_locking=self._use_locking
+ ),
+ )
+
+ def _resource_scatter_add(self, x, i, v):
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(x.handle, i, v)]
+ ):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
+
+ def _finish(self, update_ops, name_scope):
+ # Update the power accumulators.
+ with ops.control_dependencies(update_ops):
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ with ops.colocate_with(beta1_power):
+ update_beta1 = beta1_power.assign(
+ beta1_power * self._beta1_t, use_locking=self._use_locking
+ )
+ update_beta2 = beta2_power.assign(
+ beta2_power * self._beta2_t, use_locking=self._use_locking
+ )
+ return control_flow_ops.group(
+ *update_ops + [update_beta1, update_beta2], name=name_scope
+ )
diff --git a/EBMs/data.py b/EBMs/data.py
index 42d93b3..23af3e7 100644
--- a/EBMs/data.py
+++ b/EBMs/data.py
@@ -1,42 +1,48 @@
-from tensorflow.python.platform import flags
-from tensorflow.contrib.data.python.ops import batching, threadpool
-import tensorflow as tf
import json
-from torch.utils.data import Dataset
-import pickle
-import os.path as osp
import os
-import numpy as np
+import os.path as osp
+import pickle
import time
-from scipy.misc import imread, imresize
-from skimage.color import rgb2grey
-from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
-from torchvision import transforms
-from imagenet_preprocessing import ImagenetPreprocessor
+
+import numpy as np
+import tensorflow as tf
import torch
import torchvision
+from imagenet_preprocessing import ImagenetPreprocessor
+from scipy.misc import imread, imresize
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.platform import flags
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder
FLAGS = flags.FLAGS
ROOT_DIR = "./results"
# Dataset Options
-flags.DEFINE_string('dsprites_path',
- '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
- 'path to dsprites characters')
-flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image')
-flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
-flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
-flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
-flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
-flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
-flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
+flags.DEFINE_string(
+ "dsprites_path",
+ "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
+ "path to dsprites characters",
+)
+flags.DEFINE_string(
+ "imagenet_datadir", "/root/imagenet_big", "whether cutoff should always in image"
+)
+flags.DEFINE_bool("dshape_only", False, "fix all factors except for shapes")
+flags.DEFINE_bool("dpos_only", False, "fix all factors except for positions of shapes")
+flags.DEFINE_bool("dsize_only", False, "fix all factors except for size of objects")
+flags.DEFINE_bool("drot_only", False, "fix all factors except for rotation of objects")
+flags.DEFINE_bool(
+ "dsprites_restrict", False, "fix all factors except for rotation of objects"
+)
+flags.DEFINE_string("imagenet_path", "/root/imagenet", "path to imagenet images")
# Data augmentation options
-flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image')
-flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
-flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
-flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
+flags.DEFINE_bool("cutout_inside", False, "whether cutoff should always in image")
+flags.DEFINE_float("cutout_prob", 1.0, "probability of using cutout")
+flags.DEFINE_integer("cutout_mask_size", 16, "size of cutout")
+flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")
def cutout(mask_color=(0, 0, 0)):
@@ -91,13 +97,15 @@ class TFImagenetLoader(Dataset):
self.curr_sample = 0
- index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
+ index_path = osp.join(FLAGS.imagenet_datadir, "index.json")
with open(index_path) as f:
metadata = json.load(f)
- counts = metadata['record_counts']
+ counts = metadata["record_counts"]
- if split == 'train':
- file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
+ if split == "train":
+ file_names = list(
+ sorted([x for x in counts.keys() if x.startswith("train")])
+ )
result_records_to_skip = None
files = []
@@ -111,30 +119,44 @@ class TFImagenetLoader(Dataset):
# Record the number to skip in the first file
result_records_to_skip = records_to_skip
files.append(filename)
- records_to_read -= (records_in_file - records_to_skip)
+ records_to_read -= records_in_file - records_to_skip
records_to_skip = 0
else:
break
else:
- files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
+ files = list(
+ sorted([x for x in counts.keys() if x.startswith("validation")])
+ )
files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
- preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
+ preprocess_function = ImagenetPreprocessor(
+ 128, dtype=tf.float32, train=False
+ ).parse_and_preprocess
- ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
+ ds = tf.data.TFRecordDataset.from_generator(
+ lambda: files, output_types=tf.string
+ )
ds = ds.apply(tf.data.TFRecordDataset)
ds = ds.take(im_length)
ds = ds.prefetch(buffer_size=FLAGS.batch_size)
ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
- ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
+ ds = ds.apply(
+ batching.map_and_batch(
+ map_func=preprocess_function,
+ batch_size=FLAGS.batch_size,
+ num_parallel_batches=4,
+ )
+ )
ds = ds.prefetch(buffer_size=2)
ds_iterator = ds.make_initializable_iterator()
labels, images = ds_iterator.get_next()
- self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
+ self.images = tf.clip_by_value(
+ images / 256 + tf.random_uniform(tf.shape(images), 0, 1.0 / 256), 0.0, 1.0
+ )
self.labels = labels
- config = tf.ConfigProto(device_count = {'GPU': 0})
+ config = tf.ConfigProto(device_count={"GPU": 0})
sess = tf.Session(config=config)
sess.run(ds_iterator.initializer)
@@ -147,11 +169,17 @@ class TFImagenetLoader(Dataset):
sess = self.sess
- im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
+ im_corrupt = np.random.uniform(
+ 0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)
+ )
label, im = sess.run([self.labels, self.images])
im = im * self.rescale
label = np.eye(1000)[label.squeeze() - 1]
- im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
+ im, im_corrupt, label = (
+ torch.from_numpy(im),
+ torch.from_numpy(im_corrupt),
+ torch.from_numpy(label),
+ )
return im_corrupt, im, label
def __iter__(self):
@@ -160,6 +188,7 @@ class TFImagenetLoader(Dataset):
def __len__(self):
return self.im_length
+
class CelebA(Dataset):
def __init__(self):
@@ -180,25 +209,18 @@ class CelebA(Dataset):
im = imread(path)
im = imresize(im, (32, 32))
image_size = 32
- im = im / 255.
+ im = im / 255.0
- if FLAGS.datasource == 'default':
+ if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
- elif FLAGS.datasource == 'random':
- im_corrupt = np.random.uniform(
- 0, 1, size=(image_size, image_size, 3))
+ elif FLAGS.datasource == "random":
+ im_corrupt = np.random.uniform(0, 1, size=(image_size, image_size, 3))
return im_corrupt, im, label
class Cifar10(Dataset):
- def __init__(
- self,
- train=True,
- full=False,
- augment=False,
- noise=True,
- rescale=1.0):
+ def __init__(self, train=True, full=False, augment=False, noise=True, rescale=1.0):
if augment:
transform_list = [
@@ -215,16 +237,10 @@ class Cifar10(Dataset):
transform = transforms.ToTensor()
self.full = full
- self.data = CIFAR10(
- ROOT_DIR,
- transform=transform,
- train=train,
- download=True)
+ self.data = CIFAR10(ROOT_DIR, transform=transform, train=train, download=True)
self.test_data = CIFAR10(
- ROOT_DIR,
- transform=transform,
- train=False,
- download=True)
+ ROOT_DIR, transform=transform, train=False, download=True
+ )
self.one_hot_map = np.eye(10)
self.noise = noise
self.rescale = rescale
@@ -255,16 +271,18 @@ class Cifar10(Dataset):
im = im * 255 / 256
if self.noise:
- im = im * self.rescale + \
- np.random.uniform(0, self.rescale * 1 / 256., im.shape)
+ im = im * self.rescale + np.random.uniform(
+ 0, self.rescale * 1 / 256.0, im.shape
+ )
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
- if FLAGS.datasource == 'default':
+ if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
- elif FLAGS.datasource == 'random':
+ elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(
- 0.0, self.rescale, (image_size, image_size, 3))
+ 0.0, self.rescale, (image_size, image_size, 3)
+ )
return im_corrupt, im, label
@@ -287,10 +305,8 @@ class Cifar100(Dataset):
transform = transforms.ToTensor()
self.data = CIFAR100(
- "/root/cifar100",
- transform=transform,
- train=train,
- download=True)
+ "/root/cifar100", transform=transform, train=train, download=True
+ )
self.one_hot_map = np.eye(100)
def __len__(self):
@@ -308,11 +324,10 @@ class Cifar100(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
- if FLAGS.datasource == 'default':
+ if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
- elif FLAGS.datasource == 'random':
- im_corrupt = np.random.uniform(
- 0.0, 1.0, (image_size, image_size, 3))
+ elif FLAGS.datasource == "random":
+ im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
@@ -340,11 +355,10 @@ class Svhn(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
- if FLAGS.datasource == 'default':
+ if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
- elif FLAGS.datasource == 'random':
- im_corrupt = np.random.uniform(
- 0.0, 1.0, (image_size, image_size, 3))
+ elif FLAGS.datasource == "random":
+ im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
@@ -352,9 +366,8 @@ class Svhn(Dataset):
class Mnist(Dataset):
def __init__(self, train=True, rescale=1.0):
self.data = MNIST(
- "/root/mnist",
- transform=transforms.ToTensor(),
- download=True, train=train)
+ "/root/mnist", transform=transforms.ToTensor(), download=True, train=train
+ )
self.labels = np.eye(10)
self.rescale = rescale
@@ -367,13 +380,13 @@ class Mnist(Dataset):
im = im.squeeze()
# im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
# im = im.numpy() / 2 + 0.2
- im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
+ im = im.numpy() / 256 * 255 + np.random.uniform(0, 1.0 / 256, (28, 28))
im = im * self.rescale
image_size = 28
- if FLAGS.datasource == 'default':
+ if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
- elif FLAGS.datasource == 'random':
+ elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
return im_corrupt, im, label
@@ -381,54 +394,63 @@ class Mnist(Dataset):
class DSprites(Dataset):
def __init__(
- self,
- cond_size=False,
- cond_shape=False,
- cond_pos=False,
- cond_rot=False):
+ self, cond_size=False, cond_shape=False, cond_pos=False, cond_rot=False
+ ):
dat = np.load(FLAGS.dsprites_path)
if FLAGS.dshape_only:
- l = dat['latents_values']
- mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
- 31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
- self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
- self.label = np.tile(dat['latents_values'][mask], (10000, 1))
+ l = dat["latents_values"]
+ mask = (
+ (l[:, 4] == 16 / 31)
+ & (l[:, 5] == 16 / 31)
+ & (l[:, 2] == 0.5)
+ & (l[:, 3] == 30 * np.pi / 39)
+ )
+ self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
+ self.label = np.tile(dat["latents_values"][mask], (10000, 1))
self.label = self.label[:, 1:2]
elif FLAGS.dpos_only:
- l = dat['latents_values']
+ l = dat["latents_values"]
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
- mask = (l[:, 1] == 1) & (
- l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
- self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
- self.label = np.tile(dat['latents_values'][mask], (100, 1))
+ mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
+ self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
+ self.label = np.tile(dat["latents_values"][mask], (100, 1))
self.label = self.label[:, 4:] + 0.5
elif FLAGS.dsize_only:
- l = dat['latents_values']
+ l = dat["latents_values"]
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
- mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
- 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
- self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
- self.label = np.tile(dat['latents_values'][mask], (10000, 1))
- self.label = (self.label[:, 2:3])
+ mask = (
+ (l[:, 3] == 30 * np.pi / 39)
+ & (l[:, 4] == 16 / 31)
+ & (l[:, 5] == 16 / 31)
+ & (l[:, 1] == 1)
+ )
+ self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
+ self.label = np.tile(dat["latents_values"][mask], (10000, 1))
+ self.label = self.label[:, 2:3]
elif FLAGS.drot_only:
- l = dat['latents_values']
- mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
- 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
- self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
- self.label = np.tile(dat['latents_values'][mask], (100, 1))
- self.label = (self.label[:, 3:4])
+ l = dat["latents_values"]
+ mask = (
+ (l[:, 2] == 0.5)
+ & (l[:, 4] == 16 / 31)
+ & (l[:, 5] == 16 / 31)
+ & (l[:, 1] == 1)
+ )
+ self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
+ self.label = np.tile(dat["latents_values"][mask], (100, 1))
+ self.label = self.label[:, 3:4]
self.label = np.concatenate(
- [np.cos(self.label), np.sin(self.label)], axis=1)
+ [np.cos(self.label), np.sin(self.label)], axis=1
+ )
elif FLAGS.dsprites_restrict:
- l = dat['latents_values']
+ l = dat["latents_values"]
mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
- self.data = dat['imgs'][mask]
- self.label = dat['latents_values'][mask]
+ self.data = dat["imgs"][mask]
+ self.label = dat["latents_values"][mask]
else:
- self.data = dat['imgs']
- self.label = dat['latents_values']
+ self.data = dat["imgs"]
+ self.label = dat["latents_values"]
if cond_size:
self.label = self.label[:, 2:3]
@@ -439,7 +461,8 @@ class DSprites(Dataset):
elif cond_rot:
self.label = self.label[:, 3:4]
self.label = np.concatenate(
- [np.cos(self.label), np.sin(self.label)], axis=1)
+ [np.cos(self.label), np.sin(self.label)], axis=1
+ )
else:
self.label = self.label[:, 1:2]
@@ -452,20 +475,20 @@ class DSprites(Dataset):
im = self.data[index]
image_size = 64
- if not (
- FLAGS.dpos_only or FLAGS.dsize_only) and (
- not FLAGS.cond_size) and (
- not FLAGS.cond_pos) and (
- not FLAGS.cond_rot) and (
- not FLAGS.drot_only):
- label = self.identity[self.label[index].astype(
- np.int32) - 1].squeeze()
+ if (
+ not (FLAGS.dpos_only or FLAGS.dsize_only)
+ and (not FLAGS.cond_size)
+ and (not FLAGS.cond_pos)
+ and (not FLAGS.cond_rot)
+ and (not FLAGS.drot_only)
+ ):
+ label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
else:
label = self.label[index]
- if FLAGS.datasource == 'default':
+ if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
- elif FLAGS.datasource == 'random':
+ elif FLAGS.datasource == "random":
im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
return im_corrupt, im, label
@@ -478,25 +501,20 @@ class Imagenet(Dataset):
for i in range(1, 11):
f = pickle.load(
open(
- osp.join(
- FLAGS.imagenet_path,
- 'train_data_batch_{}'.format(i)),
- 'rb'))
+ osp.join(FLAGS.imagenet_path, "train_data_batch_{}".format(i)),
+ "rb",
+ )
+ )
if i == 1:
- labels = f['labels']
- data = f['data']
+ labels = f["labels"]
+ data = f["data"]
else:
- labels.extend(f['labels'])
- data = np.vstack((data, f['data']))
+ labels.extend(f["labels"])
+ data = np.vstack((data, f["data"]))
else:
- f = pickle.load(
- open(
- osp.join(
- FLAGS.imagenet_path,
- 'val_data'),
- 'rb'))
- labels = f['labels']
- data = f['data']
+ f = pickle.load(open(osp.join(FLAGS.imagenet_path, "val_data"), "rb"))
+ labels = f["labels"]
+ data = f["data"]
self.labels = labels
self.data = data
@@ -520,11 +538,10 @@ class Imagenet(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
- if FLAGS.datasource == 'default':
+ if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
- elif FLAGS.datasource == 'random':
- im_corrupt = np.random.uniform(
- 0.0, 1.0, (image_size, image_size, 3))
+ elif FLAGS.datasource == "random":
+ im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label
diff --git a/EBMs/ebm_combine.py b/EBMs/ebm_combine.py
index c79a606..ca13a3b 100644
--- a/EBMs/ebm_combine.py
+++ b/EBMs/ebm_combine.py
@@ -1,68 +1,100 @@
+import os
+import os.path as osp
+
+import numpy as np
import tensorflow as tf
-import math
-from tqdm import tqdm
-from hmc import hmc
+from custom_adam import AdamOptimizer
+from models import DspritesNet
+from scipy.misc import imsave
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset
-from models import DspritesNet
-from utils import optimistic_restore, ReplayBuffer
-import os.path as osp
-import numpy as np
-from rl_algs.logger import TensorBoardOutputFormat
-from scipy.misc import imsave
-import os
-from custom_adam import AdamOptimizer
+from tqdm import tqdm
+from utils import ReplayBuffer
-flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
-flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
-flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
-flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
-flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
-flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
-flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
-flags.DEFINE_bool('cclass', True, 'not cclass')
-flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
-flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
-flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
-flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
-flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
-flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
-flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
-flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape')
-flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape')
+flags.DEFINE_integer("batch_size", 256, "Size of inputs")
+flags.DEFINE_integer("data_workers", 4, "Number of workers to do things")
+flags.DEFINE_string("logdir", "cachedir", "directory for logging")
+flags.DEFINE_string(
+ "savedir", "cachedir", "location where log of experiments will be stored"
+)
+flags.DEFINE_integer(
+ "num_filters",
+ 64,
+ "number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
+)
+flags.DEFINE_float("step_lr", 500, "size of gradient descent size")
+flags.DEFINE_string(
+ "dsprites_path",
+ "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
+ "path to dsprites characters",
+)
+flags.DEFINE_bool("cclass", True, "not cclass")
+flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
+flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
+flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
+flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
+flags.DEFINE_bool("plot_curve", False, "Generate a curve of results")
+flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
+flags.DEFINE_string(
+ "task",
+ "conceptcombine",
+ "conceptcombine, labeldiscover, gentest, genbaseline, etc.",
+)
+flags.DEFINE_bool("joint_shape", False, "whether to use pos_size or pos_shape")
+flags.DEFINE_bool("joint_rot", False, "whether to use pos_size or pos_shape")
# Conditions on which models to use
-flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
-flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
-flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
-flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')
+flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
+flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
+flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
+flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")
-flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
-flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
-flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
-flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
-flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume')
-flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume')
-flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume')
-flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume')
-flags.DEFINE_integer('break_steps', 300, 'steps to break')
+flags.DEFINE_string("exp_size", "dsprites_2018_cond_size", "name of experiments")
+flags.DEFINE_string("exp_shape", "dsprites_2018_cond_shape", "name of experiments")
+flags.DEFINE_string("exp_pos", "dsprites_2018_cond_pos_cert", "name of experiments")
+flags.DEFINE_string("exp_rot", "dsprites_cond_rot_119_00", "name of experiments")
+flags.DEFINE_integer("resume_size", 169000, "First iteration to resume")
+flags.DEFINE_integer("resume_shape", 477000, "Second iteration to resume")
+flags.DEFINE_integer("resume_pos", 8000, "Second iteration to resume")
+flags.DEFINE_integer("resume_rot", 690000, "Second iteration to resume")
+flags.DEFINE_integer("break_steps", 300, "steps to break")
# Whether to train for gentest
-flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')
+flags.DEFINE_bool(
+ "train",
+ False,
+ "whether to train on generalization into multiple different predictions",
+)
FLAGS = flags.FLAGS
+
class DSpritesGen(Dataset):
def __init__(self, data, latents, frac=0.0):
l = latents
if FLAGS.joint_shape:
- mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
+ mask_size = (
+ (l[:, 3] == 30 * np.pi / 39)
+ & (l[:, 4] == 16 / 31)
+ & (l[:, 5] == 16 / 31)
+ & (l[:, 2] == 0.5)
+ )
elif FLAGS.joint_rot:
- mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
+ mask_size = (
+ (l[:, 1] == 1)
+ & (l[:, 4] == 16 / 31)
+ & (l[:, 5] == 16 / 31)
+ & (l[:, 2] == 0.5)
+ )
else:
- mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)
+ mask_size = (
+ (l[:, 3] == 30 * np.pi / 39)
+ & (l[:, 4] == 16 / 31)
+ & (l[:, 5] == 16 / 31)
+ & (l[:, 1] == 1)
+ )
mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
@@ -80,12 +112,14 @@ class DSpritesGen(Dataset):
self.data = np.concatenate((data_pos, data_size), axis=0)
self.label = np.concatenate((l_pos, l_size), axis=0)
- mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
+ mask_neg = (~(mask_size & mask_pos)) & (
+ (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)
+ )
data_add = data[mask_neg]
l_add = l[mask_neg]
perm_idx = np.random.permutation(data_add.shape[0])
- select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
+ select_idx = perm_idx[: int(frac * perm_idx.shape[0])]
data_add = data_add[select_idx]
l_add = l_add[select_idx]
@@ -104,7 +138,9 @@ class DSpritesGen(Dataset):
if FLAGS.joint_shape:
label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
elif FLAGS.joint_rot:
- label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
+ label_size = np.array(
+ [np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]
+ )
else:
label_size = self.label[index, 2:3]
@@ -114,14 +150,16 @@ class DSpritesGen(Dataset):
def labeldiscover(sess, kvs, data, latents, save_exp_dir):
- LABEL_SIZE = kvs['LABEL_SIZE']
- model_size = kvs['model_size']
- weight_size = kvs['weight_size']
- x_mod = kvs['X_NOISE']
+ LABEL_SIZE = kvs["LABEL_SIZE"]
+ model_size = kvs["model_size"]
+ weight_size = kvs["weight_size"]
+ x_mod = kvs["X_NOISE"]
label_output = LABEL_SIZE
for i in range(FLAGS.num_steps):
- label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
+ label_output = label_output + tf.random_normal(
+ tf.shape(label_output), mean=0.0, stddev=0.03
+ )
e_noise = model_size.forward(x_mod, weight_size, label=label_output)
label_grad = tf.gradients(e_noise, [label_output])[0]
# label_grad = tf.Print(label_grad, [label_grad])
@@ -130,13 +168,13 @@ def labeldiscover(sess, kvs, data, latents, save_exp_dir):
diffs = []
for i in range(30):
- s = i*FLAGS.batch_size
- d = (i+1)*FLAGS.batch_size
+ s = i * FLAGS.batch_size
+ d = (i + 1) * FLAGS.batch_size
data_i = data[s:d]
latent_i = latents[s:d]
latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
- feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
+ feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
size_pred = sess.run([label_output], feed_dict)[0]
size_gt = latent_i[:, 2:3]
@@ -155,9 +193,11 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
- weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
+ weights_baseline = model_baseline.construct_weights(
+ "context_baseline_{}".format(frac)
+ )
- X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
+ X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
@@ -168,14 +208,20 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
gvs = [(k, v) for (k, v) in gvs if k is not None]
train_op = optimizer.apply_gradients(gvs)
- dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
+ dataloader = DataLoader(
+ DSpritesGen(data, latents, frac=frac),
+ batch_size=FLAGS.batch_size,
+ num_workers=6,
+ drop_last=True,
+ shuffle=True,
+ )
datafull = data
itr = 0
saver = tf.train.Saver()
- vs = optimizer.variables()
+ optimizer.variables()
sess.run(tf.global_variables_initializer())
if FLAGS.train:
@@ -185,7 +231,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
data_corrupt = data_corrupt.numpy()
label_size, label_pos = label_size.numpy(), label_pos.numpy()
- data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
+ data_corrupt = np.random.randn(
+ data_corrupt.shape[0], 2 * FLAGS.num_filters
+ )
label_comb = np.concatenate([label_size, label_pos], axis=1)
feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
@@ -196,23 +244,27 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
itr += 1
- saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
+ saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))
- saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
+ saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))
l = latents
if FLAGS.joint_shape:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
else:
- mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
+ mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
+ ~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
+ )
data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen]
losses = []
- for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
- data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
+ for dat, latent in zip(
+ np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
+ ):
+ data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)
if FLAGS.joint_shape:
latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
@@ -220,16 +272,19 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
latent = np.concatenate([latent_size, latent_pos], axis=1)
feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
else:
- feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
+ feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}
loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
# print(loss)
losses.append(loss)
- print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))
-
+ print(
+ "Overall MSE for generalization of {} for fraction of {}".format(
+ np.mean(losses), frac
+ )
+ )
data_try = data_gen[:10]
- data_init = np.random.randn(10, 2*FLAGS.num_filters)
+ data_init = np.random.randn(10, 2 * FLAGS.num_filters)
if FLAGS.joint_shape:
latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
@@ -252,7 +307,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try
- im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
+ im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
+ -1, 66 * 2
+ )
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
@@ -261,38 +318,40 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
def gentest(sess, kvs, data, latents, save_exp_dir):
- X_NOISE = kvs['X_NOISE']
- LABEL_SIZE = kvs['LABEL_SIZE']
- LABEL_SHAPE = kvs['LABEL_SHAPE']
- LABEL_POS = kvs['LABEL_POS']
- LABEL_ROT = kvs['LABEL_ROT']
- model_size = kvs['model_size']
- model_shape = kvs['model_shape']
- model_pos = kvs['model_pos']
- model_rot = kvs['model_rot']
- weight_size = kvs['weight_size']
- weight_shape = kvs['weight_shape']
- weight_pos = kvs['weight_pos']
- weight_rot = kvs['weight_rot']
+ X_NOISE = kvs["X_NOISE"]
+ LABEL_SIZE = kvs["LABEL_SIZE"]
+ LABEL_SHAPE = kvs["LABEL_SHAPE"]
+ LABEL_POS = kvs["LABEL_POS"]
+ LABEL_ROT = kvs["LABEL_ROT"]
+ model_size = kvs["model_size"]
+ model_shape = kvs["model_shape"]
+ model_pos = kvs["model_pos"]
+ model_rot = kvs["model_rot"]
+ weight_size = kvs["weight_size"]
+ weight_shape = kvs["weight_shape"]
+ weight_pos = kvs["weight_pos"]
+ weight_rot = kvs["weight_rot"]
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
datafull = data
# Test combination of generalization where we use slices of both training
x_final = X_NOISE
- x_mod_size = X_NOISE
x_mod_pos = X_NOISE
for i in range(FLAGS.num_steps):
# use cond_pos
- energies = []
- x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
+ x_mod_pos = x_mod_pos + tf.random_normal(
+ tf.shape(x_mod_pos), mean=0.0, stddev=0.005
+ )
e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
# energies.append(e_noise)
x_grad = tf.gradients(e_noise, [x_final])[0]
- x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
+ x_mod_pos = x_mod_pos + tf.random_normal(
+ tf.shape(x_mod_pos), mean=0.0, stddev=0.005
+ )
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
@@ -332,42 +391,70 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
x_mod = x_mod_pos
x_final = x_mod
-
if FLAGS.joint_shape:
- loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
- model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
+ loss_kl = model_shape.forward(
+ x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True
+ ) + model_pos.forward(
+ x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
+ )
- energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
- model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
+ energy_pos = model_shape.forward(
+ X, weight_shape, reuse=True, label=LABEL_SHAPE
+ ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
- energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
- model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
+ energy_neg = model_shape.forward(
+ tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE
+ ) + model_pos.forward(
+ tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
+ )
elif FLAGS.joint_rot:
- loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
- model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
+ loss_kl = model_rot.forward(
+ x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True
+ ) + model_pos.forward(
+ x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
+ )
- energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
- model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
+ energy_pos = model_rot.forward(
+ X, weight_rot, reuse=True, label=LABEL_ROT
+ ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
- energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
- model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
+ energy_neg = model_rot.forward(
+ tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT
+ ) + model_pos.forward(
+ tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
+ )
else:
- loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
- model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
+ loss_kl = model_size.forward(
+ x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True
+ ) + model_pos.forward(
+ x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
+ )
- energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
- model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
+ energy_pos = model_size.forward(
+ X, weight_size, reuse=True, label=LABEL_SIZE
+ ) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
- energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
- model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
+ energy_neg = model_size.forward(
+ tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE
+ ) + model_pos.forward(
+ tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
+ )
- energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
+ energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
- neg_loss = coeff * (-1*energy_neg) / norm_constant
+ coeff * (-1 * energy_neg) / norm_constant
loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
- loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
+ loss_total = (
+ loss_ml
+ + tf.reduce_mean(loss_kl)
+ + 1
+ * (
+ tf.reduce_mean(tf.square(energy_pos))
+ + tf.reduce_mean(tf.square(energy_neg))
+ )
+ )
optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
gvs = optimizer.compute_gradients(loss_total)
@@ -377,7 +464,13 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
vs = optimizer.variables()
sess.run(tf.variables_initializer(vs))
- dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
+ dataloader = DataLoader(
+ DSpritesGen(data, latents),
+ batch_size=FLAGS.batch_size,
+ num_workers=6,
+ drop_last=True,
+ shuffle=True,
+ )
x_off = tf.reduce_mean(tf.square(x_mod - X))
@@ -385,12 +478,10 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
saver = tf.train.Saver()
x_mod = None
-
if FLAGS.train:
replay_buffer = ReplayBuffer(10000)
for _ in range(1):
-
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
data_corrupt = data_corrupt.numpy()[:, :, :]
data = data.numpy()[:, :, :]
@@ -398,29 +489,50 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
if x_mod is not None:
replay_buffer.add(x_mod)
replay_batch = replay_buffer.sample(FLAGS.batch_size)
- replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
+ replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.joint_shape:
- feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
+ feed_dict = {
+ X_NOISE: data_corrupt,
+ X: data,
+ LABEL_SHAPE: label_size,
+ LABEL_POS: label_pos,
+ }
elif FLAGS.joint_rot:
- feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
+ feed_dict = {
+ X_NOISE: data_corrupt,
+ X: data,
+ LABEL_ROT: label_size,
+ LABEL_POS: label_pos,
+ }
else:
- feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
+ feed_dict = {
+ X_NOISE: data_corrupt,
+ X: data,
+ LABEL_SIZE: label_size,
+ LABEL_POS: label_pos,
+ }
- _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
+ _, off_value, e_pos, e_neg, x_mod = sess.run(
+ [train_op, x_off, energy_pos, energy_neg, x_final],
+ feed_dict=feed_dict,
+ )
itr += 1
if itr % 10 == 0:
- print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
+ print(
+ "x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
+ off_value, e_pos.mean(), e_neg.mean(), itr
+ )
+ )
if itr == FLAGS.break_steps:
break
+ saver.save(sess, osp.join(save_exp_dir, "model_gentest"))
- saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))
-
- saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
+ saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))
l = latents
@@ -429,22 +541,43 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
elif FLAGS.joint_rot:
mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
else:
- mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
+ mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
+ ~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
+ )
data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen]
losses = []
- for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
+ for dat, latent in zip(
+ np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
+ ):
x = 0.5 + np.random.randn(*dat.shape)
if FLAGS.joint_shape:
- feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
+ feed_dict = {
+ LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
+ LABEL_POS: latent[:, 4:],
+ X_NOISE: x,
+ X: dat,
+ }
elif FLAGS.joint_rot:
- feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
+ feed_dict = {
+ LABEL_ROT: np.concatenate(
+ [np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
+ ),
+ LABEL_POS: latent[:, 4:],
+ X_NOISE: x,
+ X: dat,
+ }
else:
- feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
+ feed_dict = {
+ LABEL_SIZE: latent[:, 2:3],
+ LABEL_POS: latent[:, 4:],
+ X_NOISE: x,
+ X: dat,
+ }
for i in range(2):
x = sess.run([x_final], feed_dict=feed_dict)[0]
@@ -461,11 +594,25 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
latent_pos = latents_gen[:10, 4:]
if FLAGS.joint_shape:
- feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
+ feed_dict = {
+ X_NOISE: data_init,
+ LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32) - 1],
+ LABEL_POS: latent_pos,
+ }
elif FLAGS.joint_rot:
- feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
+ feed_dict = {
+ LABEL_ROT: np.concatenate(
+ [np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1
+ ),
+ LABEL_POS: latent[:10, 4:],
+ X_NOISE: data_init,
+ }
else:
- feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
+ feed_dict = {
+ X_NOISE: data_init,
+ LABEL_SIZE: latent_scale,
+ LABEL_POS: latent_pos,
+ }
x_output = sess.run([x_final], feed_dict=feed_dict)[0]
@@ -480,27 +627,28 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try
- im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
+ im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
+ -1, 66 * 2
+ )
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
-
def conceptcombine(sess, kvs, data, latents, save_exp_dir):
- X_NOISE = kvs['X_NOISE']
- LABEL_SIZE = kvs['LABEL_SIZE']
- LABEL_SHAPE = kvs['LABEL_SHAPE']
- LABEL_POS = kvs['LABEL_POS']
- LABEL_ROT = kvs['LABEL_ROT']
- model_size = kvs['model_size']
- model_shape = kvs['model_shape']
- model_pos = kvs['model_pos']
- model_rot = kvs['model_rot']
- weight_size = kvs['weight_size']
- weight_shape = kvs['weight_shape']
- weight_pos = kvs['weight_pos']
- weight_rot = kvs['weight_rot']
+ X_NOISE = kvs["X_NOISE"]
+ LABEL_SIZE = kvs["LABEL_SIZE"]
+ LABEL_SHAPE = kvs["LABEL_SHAPE"]
+ LABEL_POS = kvs["LABEL_POS"]
+ LABEL_ROT = kvs["LABEL_ROT"]
+ model_size = kvs["model_size"]
+ model_shape = kvs["model_shape"]
+ model_pos = kvs["model_pos"]
+ model_rot = kvs["model_rot"]
+ weight_size = kvs["weight_size"]
+ weight_shape = kvs["weight_shape"]
+ weight_pos = kvs["weight_pos"]
+ weight_rot = kvs["weight_rot"]
x_mod = X_NOISE
for i in range(FLAGS.num_steps):
@@ -540,13 +688,18 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
data_try = data[:10]
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
label_scale = latents[:10, 2:3]
- label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
+ label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
label_rot = latents[:10, 3:4]
label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
label_pos = latents[:10, 4:]
- feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
- LABEL_ROT: label_rot}
+ feed_dict = {
+ X_NOISE: data_init,
+ LABEL_SIZE: label_scale,
+ LABEL_SHAPE: label_shape,
+ LABEL_POS: label_pos,
+ LABEL_ROT: label_rot,
+ }
x_out = sess.run([x_final], feed_dict)[0]
im_name = "im"
@@ -569,14 +722,15 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
x_out_pad[:, 1:-1, 1:-1] = x_out
data_try_pad[:, 1:-1, 1:-1] = data_try
- im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
+ im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output)
print("Successfully saved images at {}".format(impath))
+
def main():
- data = np.load(FLAGS.dsprites_path)['imgs']
- l = latents = np.load(FLAGS.dsprites_path)['latents_values']
+ data = np.load(FLAGS.dsprites_path)["imgs"]
+ l = latents = np.load(FLAGS.dsprites_path)["latents_values"]
np.random.seed(1)
idx = np.random.permutation(data.shape[0])
@@ -589,52 +743,74 @@ def main():
# Model 1 will be conditioned on size
model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
- weight_size = model_size.construct_weights('context_0')
+ weight_size = model_size.construct_weights("context_0")
# Model 2 will be conditioned on shape
model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
- weight_shape = model_shape.construct_weights('context_1')
+ weight_shape = model_shape.construct_weights("context_1")
# Model 3 will be conditioned on position
model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
- weight_pos = model_pos.construct_weights('context_2')
+ weight_pos = model_pos.construct_weights("context_2")
# Model 4 will be conditioned on rotation
model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
- weight_rot = model_rot.construct_weights('context_3')
+ weight_rot = model_rot.construct_weights("context_3")
sess.run(tf.global_variables_initializer())
- save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
+ save_path_size = osp.join(
+ FLAGS.logdir, FLAGS.exp_size, "model_{}".format(FLAGS.resume_size)
+ )
- v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
- v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
+ v_list = tf.get_collection(
+ tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(0)
+ )
+ v_map = {
+ (v.name.replace("context_{}".format(0), "context_0")[:-2]): v for v in v_list
+ }
if FLAGS.cond_scale:
saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_size)
- save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
+ save_path_shape = osp.join(
+ FLAGS.logdir, FLAGS.exp_shape, "model_{}".format(FLAGS.resume_shape)
+ )
- v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
- v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
+ v_list = tf.get_collection(
+ tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(1)
+ )
+ v_map = {
+ (v.name.replace("context_{}".format(1), "context_0")[:-2]): v for v in v_list
+ }
if FLAGS.cond_shape:
saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_shape)
-
- save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
- v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
- v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
+ save_path_pos = osp.join(
+ FLAGS.logdir, FLAGS.exp_pos, "model_{}".format(FLAGS.resume_pos)
+ )
+ v_list = tf.get_collection(
+ tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(2)
+ )
+ v_map = {
+ (v.name.replace("context_{}".format(2), "context_0")[:-2]): v for v in v_list
+ }
saver = tf.train.Saver(v_map)
if FLAGS.cond_pos:
saver.restore(sess, save_path_pos)
-
- save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
- v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
- v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
+ save_path_rot = osp.join(
+ FLAGS.logdir, FLAGS.exp_rot, "model_{}".format(FLAGS.resume_rot)
+ )
+ v_list = tf.get_collection(
+ tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(3)
+ )
+ v_map = {
+ (v.name.replace("context_{}".format(3), "context_0")[:-2]): v for v in v_list
+ }
saver = tf.train.Saver(v_map)
if FLAGS.cond_rot:
@@ -646,53 +822,57 @@ def main():
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
- x_mod = X_NOISE
-
kvs = {}
- kvs['X_NOISE'] = X_NOISE
- kvs['LABEL_SIZE'] = LABEL_SIZE
- kvs['LABEL_SHAPE'] = LABEL_SHAPE
- kvs['LABEL_POS'] = LABEL_POS
- kvs['LABEL_ROT'] = LABEL_ROT
- kvs['model_size'] = model_size
- kvs['model_shape'] = model_shape
- kvs['model_pos'] = model_pos
- kvs['model_rot'] = model_rot
- kvs['weight_size'] = weight_size
- kvs['weight_shape'] = weight_shape
- kvs['weight_pos'] = weight_pos
- kvs['weight_rot'] = weight_rot
+ kvs["X_NOISE"] = X_NOISE
+ kvs["LABEL_SIZE"] = LABEL_SIZE
+ kvs["LABEL_SHAPE"] = LABEL_SHAPE
+ kvs["LABEL_POS"] = LABEL_POS
+ kvs["LABEL_ROT"] = LABEL_ROT
+ kvs["model_size"] = model_size
+ kvs["model_shape"] = model_shape
+ kvs["model_pos"] = model_pos
+ kvs["model_rot"] = model_rot
+ kvs["weight_size"] = weight_size
+ kvs["weight_shape"] = weight_shape
+ kvs["weight_pos"] = weight_pos
+ kvs["weight_rot"] = weight_rot
- save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
+ save_exp_dir = osp.join(
+ FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
+ )
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
-
- if FLAGS.task == 'conceptcombine':
+ if FLAGS.task == "conceptcombine":
conceptcombine(sess, kvs, data, latents, save_exp_dir)
- elif FLAGS.task == 'labeldiscover':
+ elif FLAGS.task == "labeldiscover":
labeldiscover(sess, kvs, data, latents, save_exp_dir)
- elif FLAGS.task == 'gentest':
- save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
+ elif FLAGS.task == "gentest":
+ save_exp_dir = osp.join(
+ FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
+ )
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
gentest(sess, kvs, data, latents, save_exp_dir)
- elif FLAGS.task == 'genbaseline':
- save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
+ elif FLAGS.task == "genbaseline":
+ save_exp_dir = osp.join(
+ FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
+ )
if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir)
if FLAGS.plot_curve:
mse_losses = []
- for frac in [i/10 for i in range(11)]:
- mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
+ for frac in [i / 10 for i in range(11)]:
+ mse_loss = genbaseline(
+ sess, kvs, data, latents, save_exp_dir, frac=frac
+ )
mse_losses.append(mse_loss)
np.save("mse_baseline_comb.npy", mse_losses)
else:
genbaseline(sess, kvs, data, latents, save_exp_dir)
-
if __name__ == "__main__":
main()
diff --git a/EBMs/ebm_sandbox.py b/EBMs/ebm_sandbox.py
index 531d517..c5c5ed2 100644
--- a/EBMs/ebm_sandbox.py
+++ b/EBMs/ebm_sandbox.py
@@ -1,122 +1,148 @@
+import os
+import os.path as osp
+
+import numpy as np
+import sklearn.metrics as sk
import tensorflow as tf
-import math
-from tqdm import tqdm
+from baselines.common.tf_util import initialize
+from data import Cifar10, Cifar100, DSprites, Imagenet, Svhn, Textures
+from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider
+from scipy.misc import imsave
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
-import torch
-from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, DspritesNet
-from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, DSprites
-from utils import optimistic_restore, set_seed
-import os.path as osp
-import numpy as np
-from baselines.logger import TensorBoardOutputFormat
-from scipy.misc import imsave
-import os
-import sklearn.metrics as sk
-from baselines.common.tf_util import initialize
-from scipy.linalg import eig
-import matplotlib.pyplot as plt
+from tqdm import tqdm
+from utils import optimistic_restore
# set_seed(1)
-flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
-flags.DEFINE_string('dataset', 'cifar10', 'omniglot or imagenet or omniglotfull or cifar10 or mnist or dsprites')
-flags.DEFINE_string('logdir', 'sandbox_cachedir', 'location where log of experiments will be stored')
-flags.DEFINE_string('task', 'label', 'using conditional energy based models for classification'
- 'anticorrupt: restore salt and pepper noise),'
- ' boxcorrupt: restore empty portion of image'
- 'or crossclass: change images from one class to another'
- 'or cycleclass: view image change across a label'
- 'or nearestneighbor which returns the nearest images in the test set'
- 'or latent to traverse the latent space of an EBM through eigenvectors of the hessian (dsprites only)'
- 'or mixenergy to evaluate out of distribution generalization compared to other datasets')
-flags.DEFINE_bool('hessian', True, 'Whether to use the hessian or the Jacobian for latent traversals')
-flags.DEFINE_string('exp', 'default', 'name of experiments')
-flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
-flags.DEFINE_integer('batch_size', 32, 'Size of inputs')
-flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
-flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
-flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
-flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
-flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
-flags.DEFINE_bool('train', True, 'Whether to train or test network')
-flags.DEFINE_bool('single', False, 'whether to use one sample to debug')
-flags.DEFINE_bool('cclass', True, 'whether to use a conditional model (required for task label)')
-flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
-flags.DEFINE_float('step_lr', 10.0, 'step size for updates on label')
-flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
-flags.DEFINE_bool('large_model', False, 'Whether to use a large model')
-flags.DEFINE_bool('larger_model', False, 'Whether to use a larger model')
-flags.DEFINE_bool('wider_model', False, 'Whether to use a widermodel model')
-flags.DEFINE_bool('svhn', False, 'Whether to test on SVHN')
+flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
+flags.DEFINE_string(
+ "dataset",
+ "cifar10",
+ "omniglot or imagenet or omniglotfull or cifar10 or mnist or dsprites",
+)
+flags.DEFINE_string(
+ "logdir", "sandbox_cachedir", "location where log of experiments will be stored"
+)
+flags.DEFINE_string(
+ "task",
+ "label",
+ "using conditional energy based models for classification"
+ "anticorrupt: restore salt and pepper noise),"
+ " boxcorrupt: restore empty portion of image"
+ "or crossclass: change images from one class to another"
+ "or cycleclass: view image change across a label"
+ "or nearestneighbor which returns the nearest images in the test set"
+ "or latent to traverse the latent space of an EBM through eigenvectors of the hessian (dsprites only)"
+ "or mixenergy to evaluate out of distribution generalization compared to other datasets",
+)
+flags.DEFINE_bool(
+ "hessian", True, "Whether to use the hessian or the Jacobian for latent traversals"
+)
+flags.DEFINE_string("exp", "default", "name of experiments")
+flags.DEFINE_integer(
+ "data_workers", 5, "Number of different data workers to load data in parallel"
+)
+flags.DEFINE_integer("batch_size", 32, "Size of inputs")
+flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from")
+flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
+flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
+flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
+flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
+flags.DEFINE_bool("train", True, "Whether to train or test network")
+flags.DEFINE_bool("single", False, "whether to use one sample to debug")
+flags.DEFINE_bool(
+ "cclass", True, "whether to use a conditional model (required for task label)"
+)
+flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
+flags.DEFINE_float("step_lr", 10.0, "step size for updates on label")
+flags.DEFINE_float("proj_norm", 0.0, "Maximum change of input images")
+flags.DEFINE_bool("large_model", False, "Whether to use a large model")
+flags.DEFINE_bool("larger_model", False, "Whether to use a larger model")
+flags.DEFINE_bool("wider_model", False, "Whether to use a widermodel model")
+flags.DEFINE_bool("svhn", False, "Whether to test on SVHN")
# Conditions for mixenergy (outlier detection)
-flags.DEFINE_bool('svhnmix', False, 'Whether to test mix on SVHN')
-flags.DEFINE_bool('cifar100mix', False, 'Whether to test mix on CIFAR100')
-flags.DEFINE_bool('texturemix', False, 'Whether to test mix on Textures dataset')
-flags.DEFINE_bool('randommix', False, 'Whether to test mix on random dataset')
+flags.DEFINE_bool("svhnmix", False, "Whether to test mix on SVHN")
+flags.DEFINE_bool("cifar100mix", False, "Whether to test mix on CIFAR100")
+flags.DEFINE_bool("texturemix", False, "Whether to test mix on Textures dataset")
+flags.DEFINE_bool("randommix", False, "Whether to test mix on random dataset")
# Conditions for label task (adversarial classification)
-flags.DEFINE_integer('lival', 8, 'Value of constraint for li')
-flags.DEFINE_integer('l2val', 40, 'Value of constraint for l2')
-flags.DEFINE_integer('pgd', 0, 'number of steps project gradient descent to run')
-flags.DEFINE_integer('lnorm', -1, 'linfinity is -1, l2 norm is 2')
-flags.DEFINE_bool('labelgrid', False, 'Make a grid of labels')
+flags.DEFINE_integer("lival", 8, "Value of constraint for li")
+flags.DEFINE_integer("l2val", 40, "Value of constraint for l2")
+flags.DEFINE_integer("pgd", 0, "number of steps project gradient descent to run")
+flags.DEFINE_integer("lnorm", -1, "linfinity is -1, l2 norm is 2")
+flags.DEFINE_bool("labelgrid", False, "Make a grid of labels")
# Conditions on which models to use
-flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
-flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
-flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
-flags.DEFINE_bool('cond_size', True, 'whether to condition on scale')
+flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
+flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
+flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
+flags.DEFINE_bool("cond_size", True, "whether to condition on scale")
FLAGS = flags.FLAGS
+
def rescale_im(im):
im = np.clip(im, 0, 1)
return np.round(im * 255).astype(np.uint8)
+
def label(dataloader, test_dataloader, target_vars, sess, l1val=8, l2val=40):
- X = target_vars['X']
- Y = target_vars['Y']
- Y_GT = target_vars['Y_GT']
- accuracy = target_vars['accuracy']
- train_op = target_vars['train_op']
- l1_norm = target_vars['l1_norm']
- l2_norm = target_vars['l2_norm']
+ X = target_vars["X"]
+ Y = target_vars["Y"]
+ Y_GT = target_vars["Y_GT"]
+ accuracy = target_vars["accuracy"]
+ target_vars["train_op"]
+ l1_norm = target_vars["l1_norm"]
+ l2_norm = target_vars["l2_norm"]
label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10))
label_init = label_init / label_init.sum(axis=1, keepdims=True)
- label_init = np.tile(np.eye(10)[None :, :], (FLAGS.batch_size, 1, 1))
+ label_init = np.tile(np.eye(10)[None:, :], (FLAGS.batch_size, 1, 1))
label_init = np.reshape(label_init, (-1, 10))
for i in range(1):
emp_accuracies = []
for data_corrupt, data, label_gt in tqdm(test_dataloader):
- feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
+ feed_dict = {
+ X: data,
+ Y_GT: label_gt,
+ Y: label_init,
+ l1_norm: l1val,
+ l2_norm: l2val,
+ }
emp_accuracy = sess.run([accuracy], feed_dict)
emp_accuracies.append(emp_accuracy)
print(np.array(emp_accuracies).mean())
- print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val))
+ print(
+ "Received total accuracy of {} for li of {} and l2 of {}".format(
+ np.array(emp_accuracies).mean(), l1val, l2val
+ )
+ )
return np.array(emp_accuracies).mean()
-def labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=8, l2val=40):
- X = target_vars['X']
- Y = target_vars['Y']
- Y_GT = target_vars['Y_GT']
- accuracy = target_vars['accuracy']
- train_op = target_vars['train_op']
- l1_norm = target_vars['l1_norm']
- l2_norm = target_vars['l2_norm']
+def labelfinetune(
+ dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=8, l2val=40
+):
+ X = target_vars["X"]
+ Y = target_vars["Y"]
+ Y_GT = target_vars["Y_GT"]
+ accuracy = target_vars["accuracy"]
+ train_op = target_vars["train_op"]
+ l1_norm = target_vars["l1_norm"]
+ l2_norm = target_vars["l2_norm"]
label_init = np.random.uniform(0, 1, (FLAGS.batch_size, 10))
label_init = label_init / label_init.sum(axis=1, keepdims=True)
- label_init = np.tile(np.eye(10)[None :, :], (FLAGS.batch_size, 1, 1))
+ label_init = np.tile(np.eye(10)[None:, :], (FLAGS.batch_size, 1, 1))
label_init = np.reshape(label_init, (-1, 10))
itr = 0
@@ -136,27 +162,35 @@ def labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver
saver.restore(sess, osp.join(savedir, "model_supervised"))
-
for i in range(1):
emp_accuracies = []
for data_corrupt, data, label_gt in tqdm(test_dataloader):
- feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
+ feed_dict = {
+ X: data,
+ Y_GT: label_gt,
+ Y: label_init,
+ l1_norm: l1val,
+ l2_norm: l2val,
+ }
emp_accuracy = sess.run([accuracy], feed_dict)
emp_accuracies.append(emp_accuracy)
print(np.array(emp_accuracies).mean())
-
- print("Received total accuracy of {} for li of {} and l2 of {}".format(np.array(emp_accuracies).mean(), l1val, l2val))
+ print(
+ "Received total accuracy of {} for li of {} and l2 of {}".format(
+ np.array(emp_accuracies).mean(), l1val, l2val
+ )
+ )
return np.array(emp_accuracies).mean()
def energyeval(dataloader, test_dataloader, target_vars, sess):
- X = target_vars['X']
- Y_GT = target_vars['Y_GT']
- energy = target_vars['energy']
- energy_end = target_vars['energy_end']
+ X = target_vars["X"]
+ Y_GT = target_vars["Y_GT"]
+ energy = target_vars["energy"]
+ target_vars["energy_end"]
test_energies = []
train_energies = []
@@ -173,29 +207,55 @@ def energyeval(dataloader, test_dataloader, target_vars, sess):
print(len(train_energies))
print(len(test_energies))
- print("Train energies of {} with std {}".format(np.mean(train_energies), np.std(train_energies)))
- print("Test energies of {} with std {}".format(np.mean(test_energies), np.std(test_energies)))
+ print(
+ "Train energies of {} with std {}".format(
+ np.mean(train_energies), np.std(train_energies)
+ )
+ )
+ print(
+ "Test energies of {} with std {}".format(
+ np.mean(test_energies), np.std(test_energies)
+ )
+ )
np.save("train_ebm.npy", train_energies)
np.save("test_ebm.npy", test_energies)
def energyevalmix(dataloader, test_dataloader, target_vars, sess):
- X = target_vars['X']
- Y_GT = target_vars['Y_GT']
- energy = target_vars['energy']
+ X = target_vars["X"]
+ Y_GT = target_vars["Y_GT"]
+ energy = target_vars["energy"]
if FLAGS.svhnmix:
dataset = Svhn(train=False)
- test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
+ test_dataloader_val = DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.data_workers,
+ shuffle=True,
+ drop_last=False,
+ )
test_iter = iter(test_dataloader_val)
elif FLAGS.cifar100mix:
dataset = Cifar100(train=False)
- test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
+ test_dataloader_val = DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.data_workers,
+ shuffle=True,
+ drop_last=False,
+ )
test_iter = iter(test_dataloader_val)
elif FLAGS.texturemix:
dataset = Textures()
- test_dataloader_val = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
+ test_dataloader_val = DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.data_workers,
+ shuffle=True,
+ drop_last=False,
+ )
test_iter = iter(test_dataloader_val)
probs = []
@@ -218,11 +278,11 @@ def energyevalmix(dataloader, test_dataloader, target_vars, sess):
data_other = data[data_idx]
data_mix = (data + data_other) / 2
- data_mix = data_mix[:data.shape[0]]
+ data_mix = data_mix[: data.shape[0]]
if FLAGS.cclass:
# It's unfair to take a random class
- label_gt= np.tile(np.eye(10), (data.shape[0], 1, 1))
+ label_gt = np.tile(np.eye(10), (data.shape[0], 1, 1))
label_gt = label_gt.reshape(data.shape[0] * 10, 10)
data_mix = np.tile(data_mix[:, None, :, :, :], (1, 10, 1, 1, 1))
data = np.tile(data[:, None, :, :, :], (1, 10, 1, 1, 1))
@@ -230,7 +290,6 @@ def energyevalmix(dataloader, test_dataloader, target_vars, sess):
data_mix = data_mix.reshape(-1, 32, 32, 3)
data = data.reshape(-1, 32, 32, 3)
-
feed_dict = {X: data, Y_GT: label_gt}
feed_dict_neg = {X: data_mix, Y_GT: label_gt}
@@ -241,12 +300,12 @@ def energyevalmix(dataloader, test_dataloader, target_vars, sess):
pos_energy = pos_energy.reshape(-1, 10).min(axis=1)
neg_energy = neg_energy.reshape(-1, 10).min(axis=1)
- probs.extend(list(-1*pos_energy))
- probs.extend(list(-1*neg_energy))
- pos.extend(list(-1*pos_energy))
- negs.extend(list(-1*neg_energy))
- labels.extend([1]*pos_energy.shape[0])
- labels.extend([0]*neg_energy.shape[0])
+ probs.extend(list(-1 * pos_energy))
+ probs.extend(list(-1 * neg_energy))
+ pos.extend(list(-1 * pos_energy))
+ negs.extend(list(-1 * neg_energy))
+ labels.extend([1] * pos_energy.shape[0])
+ labels.extend([0] * neg_energy.shape[0])
pos, negs = np.array(pos), np.array(negs)
np.save("pos.npy", pos)
@@ -256,11 +315,13 @@ def energyevalmix(dataloader, test_dataloader, target_vars, sess):
def anticorrupt(dataloader, weights, model, target_vars, logdir, sess):
- X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
+ X, Y_GT, X_final = target_vars["X"], target_vars["Y_GT"], target_vars["X_final"]
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
- noise = np.random.uniform(0, 1, size=[data.shape[0], data.shape[1], data.shape[2]])
+ noise = np.random.uniform(
+ 0, 1, size=[data.shape[0], data.shape[1], data.shape[2]]
+ )
low_mask = noise < 0.05
high_mask = (noise > 0.05) & (noise < 0.1)
@@ -276,33 +337,34 @@ def anticorrupt(dataloader, weights, model, target_vars, logdir, sess):
data_corrupt = sess.run([X_final], feed_dict)[0]
data_uncorrupt = data_corrupt
- data_corrupt, data_uncorrupt, data = rescale_im(data_corrupt_init), rescale_im(data_uncorrupt), rescale_im(data)
+ data_corrupt, data_uncorrupt, data = (
+ rescale_im(data_corrupt_init),
+ rescale_im(data_uncorrupt),
+ rescale_im(data),
+ )
- panel_im = np.zeros((32*20, 32*3, 3)).astype(np.uint8)
+ panel_im = np.zeros((32 * 20, 32 * 3, 3)).astype(np.uint8)
for i in range(20):
- panel_im[32*i:32*i+32, :32] = data_corrupt[i]
- panel_im[32*i:32*i+32, 32:64] = data_uncorrupt[i]
- panel_im[32*i:32*i+32, 64:] = data[i]
+ panel_im[32 * i : 32 * i + 32, :32] = data_corrupt[i]
+ panel_im[32 * i : 32 * i + 32, 32:64] = data_uncorrupt[i]
+ panel_im[32 * i : 32 * i + 32, 64:] = data[i]
imsave(osp.join(logdir, "anticorrupt.png"), panel_im)
assert False
def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess):
- X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
+ X, Y_GT, X_final = target_vars["X"], target_vars["Y_GT"], target_vars["X_final"]
eval_im = 10000
data_diff = []
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
- data_uncorrupts = []
data_corrupt = data.copy()
data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3))
- data_corrupt_init = data_corrupt
-
for j in range(10):
feed_dict = {X: data_corrupt, Y_GT: label_gt}
data_corrupt = sess.run([X_final], feed_dict)[0]
@@ -313,7 +375,11 @@ def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir,
if len(data_diff) > eval_im:
break
- print("Mean {} and std {} for train dataloader".format(np.mean(data_diff), np.std(data_diff)))
+ print(
+ "Mean {} and std {} for train dataloader".format(
+ np.mean(data_diff), np.std(data_diff)
+ )
+ )
np.save("data_diff_train_image.npy", data_diff)
@@ -321,13 +387,10 @@ def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir,
for data_corrupt, data, label_gt in tqdm(test_dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
- data_uncorrupts = []
data_corrupt = data.copy()
data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (FLAGS.batch_size, 16, 32, 3))
- data_corrupt_init = data_corrupt
-
for j in range(10):
feed_dict = {X: data_corrupt, Y_GT: label_gt}
data_corrupt = sess.run([X_final], feed_dict)[0]
@@ -337,13 +400,22 @@ def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir,
if len(data_diff) > eval_im:
break
- print("Mean {} and std {} for test dataloader".format(np.mean(data_diff), np.std(data_diff)))
+ print(
+ "Mean {} and std {} for test dataloader".format(
+ np.mean(data_diff), np.std(data_diff)
+ )
+ )
np.save("data_diff_test_image.npy", data_diff)
def crossclass(dataloader, weights, model, target_vars, logdir, sess):
- X, Y_GT, X_mods, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_mods'], target_vars['X_final']
+ X, Y_GT, X_mods, X_final = (
+ target_vars["X"],
+ target_vars["Y_GT"],
+ target_vars["X_mods"],
+ target_vars["X_final"],
+ )
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
data_corrupt = data.copy()
@@ -359,21 +431,21 @@ def crossclass(dataloader, weights, model, target_vars, logdir, sess):
feed_dict = {X: data_mod, Y_GT: label_gt}
data_mod = sess.run(X_final, feed_dict)
-
-
data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
data_mods = [rescale_im(data_mod) for data_mod in data_mods]
- panel_im = np.zeros((32*20, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
+ panel_im = np.zeros((32 * 20, 32 * (len(data_mods) + 2), 3)).astype(np.uint8)
for i in range(20):
- panel_im[32*i:32*i+32, :32] = data_corrupt[i]
+ panel_im[32 * i : 32 * i + 32, :32] = data_corrupt[i]
for j in range(len(data_mods)):
- panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
+ panel_im[32 * i : 32 * i + 32, 32 * (j + 1) : 32 * (j + 2)] = data_mods[
+ j
+ ][i]
- panel_im[32*i:32*i+32, -32:] = data[i]
+ panel_im[32 * i : 32 * i + 32, -32:] = data[i]
imsave(osp.join(logdir, "crossclass.png"), panel_im)
assert False
@@ -381,18 +453,16 @@ def crossclass(dataloader, weights, model, target_vars, logdir, sess):
def cycleclass(dataloader, weights, model, target_vars, logdir, sess):
# X, Y_GT, X_final, X_targ = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'], target_vars['X_targ']
- X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
+ X, Y_GT, X_final = target_vars["X"], target_vars["Y_GT"], target_vars["X_final"]
for data_corrupt, data, label_gt in tqdm(dataloader):
data, label_gt = data.numpy(), label_gt.numpy()
data_corrupt = data_corrupt.numpy()
-
data_mods = []
x_curr = data_corrupt
- x_target = np.random.uniform(0, 1, data_corrupt.shape)
+ np.random.uniform(0, 1, data_corrupt.shape)
# x_target = np.tile(x_target, (1, 32, 32, 1))
-
for i in range(20):
feed_dict = {X: x_curr, Y_GT: label_gt}
x_curr_new = sess.run(X_final, feed_dict)
@@ -400,32 +470,34 @@ def cycleclass(dataloader, weights, model, target_vars, logdir, sess):
data_mods.append(x_curr_new)
if i > 30:
- x_target = np.random.uniform(0, 1, data_corrupt.shape)
+ np.random.uniform(0, 1, data_corrupt.shape)
data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
data_mods = [rescale_im(data_mod) for data_mod in data_mods]
- panel_im = np.zeros((32*100, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
+ panel_im = np.zeros((32 * 100, 32 * (len(data_mods) + 2), 3)).astype(np.uint8)
for i in range(100):
- panel_im[32*i:32*i+32, :32] = data_corrupt[i]
+ panel_im[32 * i : 32 * i + 32, :32] = data_corrupt[i]
for j in range(len(data_mods)):
- panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
+ panel_im[32 * i : 32 * i + 32, 32 * (j + 1) : 32 * (j + 2)] = data_mods[
+ j
+ ][i]
- panel_im[32*i:32*i+32, -32:] = data[i]
+ panel_im[32 * i : 32 * i + 32, -32:] = data[i]
imsave(osp.join(logdir, "cycleclass.png"), panel_im)
assert False
def democlass(dataloader, weights, model, target_vars, logdir, sess):
- X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
- panel_im = np.zeros((5*32, 10*32, 3)).astype(np.uint8)
+ X, Y_GT, X_final = target_vars["X"], target_vars["Y_GT"], target_vars["X_final"]
+ panel_im = np.zeros((5 * 32, 10 * 32, 3)).astype(np.uint8)
for i in range(10):
data_corrupt = np.random.uniform(0, 1, (5, 32, 32, 3))
- label_gt = np.tile(np.eye(10)[i:i+1], (5, 1))
+ label_gt = np.tile(np.eye(10)[i : i + 1], (5, 1))
feed_dict = {X: data_corrupt, Y_GT: label_gt}
x_final = sess.run([X_final], feed_dict)[0]
@@ -439,7 +511,9 @@ def democlass(dataloader, weights, model, target_vars, logdir, sess):
row_idx = row * 32
for j in range(5):
- panel_im[row_idx:row_idx+32, start_idx+j*32:start_idx+(j+1) * 32] = x_final[j]
+ panel_im[
+ row_idx : row_idx + 32, start_idx + j * 32 : start_idx + (j + 1) * 32
+ ] = x_final[j]
imsave(osp.join(logdir, "democlass.png"), panel_im)
@@ -452,10 +526,10 @@ def construct_finetune_label(weight, X, Y, Y_GT, model, target_vars):
batch_size = tf.shape(X)[0]
X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
- Y_new = tf.reshape(Y, (batch_size*10, 10))
+ Y_new = tf.reshape(Y, (batch_size * 10, 10))
- X_min = X - 8 / 255.
- X_max = X + 8 / 255.
+ X_min = X - 8 / 255.0
+ X_max = X + 8 / 255.0
for i in range(num_steps):
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
@@ -463,7 +537,6 @@ def construct_finetune_label(weight, X, Y, Y_GT, model, target_vars):
energy_noise = model.forward(X, weights, label=Y, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
-
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
@@ -485,58 +558,62 @@ def construct_finetune_label(weight, X, Y, Y_GT, model, target_vars):
print("Constructed loop {} of pgd attack".format(i))
X_init = X
if i == 0:
- X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
+ X = (
+ X
+ + tf.to_float(
+ tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)
+ )
+ / 255.0
+ )
logit = compute_logit(X)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
- x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
+ x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.0
X = X + 2 * x_grad
if FLAGS.lnorm == -1:
X = tf.maximum(tf.minimum(X, X_max), X_min)
elif FLAGS.lnorm == 2:
- X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])
-
+ X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255.0, axes=[1, 2, 3])
energy = compute_logit(X, num_steps=0)
logits = energy
labels = tf.argmax(Y_GT, axis=1)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logits)
-
optimizer = tf.train.AdamOptimizer(1e-3)
train_op = optimizer.minimize(loss)
accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, axis=1), labels)
- target_vars['accuracy'] = accuracy
- target_vars['train_op'] = train_op
- target_vars['l1_norm'] = l1_norm
- target_vars['l2_norm'] = l2_norm
+ target_vars["accuracy"] = accuracy
+ target_vars["train_op"] = train_op
+ target_vars["l1_norm"] = l1_norm
+ target_vars["l2_norm"] = l2_norm
def construct_latent(weights, X, Y_GT, model, target_vars):
- eps = 0.001
X_init = X[0:1]
def traversals(model, X, weights, Y_GT):
if FLAGS.hessian:
e_pos = model.forward(X, weights, label=Y_GT)
hessian = tf.hessians(e_pos, X)
- hessian = tf.reshape(hessian, (1, 64*64, 64*64))[0]
+ hessian = tf.reshape(hessian, (1, 64 * 64, 64 * 64))[0]
e, v = tf.linalg.eigh(hessian)
else:
latent = model.forward(X, weights, label=Y_GT, return_logit=True)
latents = tf.split(latent, 128, axis=1)
jacobian = [tf.gradients(latent, X)[0] for latent in latents]
jacobian = tf.stack(jacobian, axis=1)
- jacobian = tf.reshape(jacobian, (tf.shape(jacobian)[1], tf.shape(jacobian)[1], 64*64))
+ jacobian = tf.reshape(
+ jacobian, (tf.shape(jacobian)[1], tf.shape(jacobian)[1], 64 * 64)
+ )
s, _, v = tf.linalg.svd(jacobian)
return v
-
var_scale = 1.0
n = 3
xs = []
@@ -559,7 +636,7 @@ def construct_latent(weights, X, Y_GT, model, target_vars):
e_pos = model.forward(x_stack, weights, label=Y_GT)
x_grad = tf.gradients(e_pos, [x_stack])[0]
- x_stack = x_stack - 4*FLAGS.step_lr * x_grad
+ x_stack = x_stack - 4 * FLAGS.step_lr * x_grad
x_stack = tf.clip_by_value(x_stack, 0, 1)
@@ -589,11 +666,13 @@ def construct_latent(weights, X, Y_GT, model, target_vars):
energys = []
for i in range(20):
- x_mods_stack = x_mods_stack + tf.random_normal(tf.shape(x_mods_stack), mean=0.0, stddev=0.005)
+ x_mods_stack = x_mods_stack + tf.random_normal(
+ tf.shape(x_mods_stack), mean=0.0, stddev=0.005
+ )
e_pos = model.forward(x_mods_stack, weights, label=Y_GT)
x_grad = tf.gradients(e_pos, [x_mods_stack])[0]
- x_mods_stack = x_mods_stack - 4*FLAGS.step_lr * x_grad
+ x_mods_stack = x_mods_stack - 4 * FLAGS.step_lr * x_grad
# x_mods_stack = x_mods_stack + 0.1 * eigs_stack
x_mods_stack = tf.clip_by_value(x_mods_stack, 0, 1)
@@ -605,29 +684,29 @@ def construct_latent(weights, X, Y_GT, model, target_vars):
# target_vars['hessian'] = hessian
# target_vars['e'] = e
- target_vars['v'] = v
- target_vars['x_stack'] = x_stack
- target_vars['x_refine'] = x_refine
- target_vars['es'] = es
+ target_vars["v"] = v
+ target_vars["x_stack"] = x_stack
+ target_vars["x_refine"] = x_refine
+ target_vars["es"] = es
# target_vars['e_base'] = e_pos_base
def latent(test_dataloader, weights, model, target_vars, sess):
- X = target_vars['X']
- Y_GT = target_vars['Y_GT']
+ X = target_vars["X"]
+ target_vars["Y_GT"]
# hessian = target_vars['hessian']
# e = target_vars['e']
- v = target_vars['v']
- x_stack = target_vars['x_stack']
- x_refine = target_vars['x_refine']
- es = target_vars['es']
+ target_vars["v"]
+ x_stack = target_vars["x_stack"]
+ x_refine = target_vars["x_refine"]
+ es = target_vars["es"]
# e_pos_base = target_vars['e_base']
# e_pos_hess_modify = target_vars['e_pos_hessian']
data_corrupt, data, label_gt = iter(test_dataloader).next()
data = data.numpy()
x_init = np.tile(data[0:1], (6, 1, 1))
- x_mod, = sess.run([x_stack], {X: data})
+ (x_mod,) = sess.run([x_stack], {X: data})
# print("Value of original starting image: ", e_pos)
# print("Value of energy of hessian: ", e_pos_hess)
x_mod = x_mod.squeeze()
@@ -643,7 +722,6 @@ def latent(test_dataloader, weights, model, target_vars, sess):
x_mod_list = x_mod_list[:]
-
series_xmod = np.stack(x_mod_list, axis=1)
series_header = np.tile(data[0:1, None, :, :], (1, len(x_mod_list), 1, 1))
@@ -655,7 +733,9 @@ def latent(test_dataloader, weights, model, target_vars, sess):
series_total = series_total_full
- series_total = series_total.transpose((0, 2, 1, 3)).reshape((-1, len(x_mod_list)*66))
+ series_total = series_total.transpose((0, 2, 1, 3)).reshape(
+ (-1, len(x_mod_list) * 66)
+ )
im_total = rescale_im(series_total)
imsave("latent_comb.png", im_total)
@@ -671,7 +751,7 @@ def construct_label(weights, X, Y, Y_GT, model, target_vars):
# Y = Y / tf.reduce_sum(Y, axis=[1], keepdims=True)
- e_bias = tf.get_variable('e_bias', shape=10, initializer=tf.initializers.zeros())
+ e_bias = tf.get_variable("e_bias", shape=10, initializer=tf.initializers.zeros())
l1_norm = tf.placeholder(shape=(), dtype=tf.float32)
l2_norm = tf.placeholder(shape=(), dtype=tf.float32)
@@ -679,10 +759,10 @@ def construct_label(weights, X, Y, Y_GT, model, target_vars):
batch_size = tf.shape(X)[0]
X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
- Y_new = tf.reshape(Y, (batch_size*10, 10))
+ Y_new = tf.reshape(Y, (batch_size * 10, 10))
- X_min = X - 8 / 255.
- X_max = X + 8 / 255.
+ X_min = X - 8 / 255.0
+ X_max = X + 8 / 255.0
for i in range(num_steps):
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
@@ -690,7 +770,6 @@ def construct_label(weights, X, Y, Y_GT, model, target_vars):
energy_noise = model.forward(X, weights, label=Y, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
-
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
@@ -705,43 +784,52 @@ def construct_label(weights, X, Y, Y_GT, model, target_vars):
return energy
-
# eps_norm = 30
- X_min = X - l1_norm / 255.
- X_max = X + l1_norm / 255.
+ X_min = X - l1_norm / 255.0
+ X_max = X + l1_norm / 255.0
for i in range(FLAGS.pgd):
print("Constructed loop {} of pgd attack".format(i))
X_init = X
if i == 0:
- X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
+ X = (
+ X
+ + tf.to_float(
+ tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)
+ )
+ / 255.0
+ )
logit = compute_logit(X)
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
- x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
+ x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.0
X = X + 2 * x_grad
if FLAGS.lnorm == -1:
X = tf.maximum(tf.minimum(X, X_max), X_min)
elif FLAGS.lnorm == 2:
- X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])
+ X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255.0, axes=[1, 2, 3])
- energy_stopped = compute_logit(X, stop_grad=True, num_steps=FLAGS.num_steps) + e_bias
+ energy_stopped = (
+ compute_logit(X, stop_grad=True, num_steps=FLAGS.num_steps) + e_bias
+ )
# # Y = tf.Print(Y, [Y])
labels = tf.argmax(Y_GT, axis=1)
# max_z = tf.argmax(energy_stopped, axis=1)
- loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=energy_stopped)
+ loss = tf.nn.softmax_cross_entropy_with_logits_v2(
+ labels=Y_GT, logits=energy_stopped
+ )
optimizer = tf.train.AdamOptimizer(1e-2)
train_op = optimizer.minimize(loss)
accuracy = tf.contrib.metrics.accuracy(tf.argmax(energy_stopped, axis=1), labels)
- target_vars['accuracy'] = accuracy
- target_vars['train_op'] = train_op
- target_vars['l1_norm'] = l1_norm
- target_vars['l2_norm'] = l2_norm
+ target_vars["accuracy"] = accuracy
+ target_vars["train_op"] = train_op
+ target_vars["l1_norm"] = l1_norm
+ target_vars["l2_norm"] = l2_norm
def construct_energy(weights, X, Y, Y_GT, model, target_vars):
@@ -759,9 +847,8 @@ def construct_energy(weights, X, Y, Y_GT, model, target_vars):
X = X - FLAGS.step_lr * x_grad
X = tf.clip_by_value(X, 0, 1)
-
- target_vars['energy'] = energy
- target_vars['energy_end'] = energy_noise
+ target_vars["energy"] = energy
+ target_vars["energy_end"] = energy_noise
def construct_steps(weights, X, Y_GT, model, target_vars):
@@ -786,8 +873,7 @@ def construct_steps(weights, X, Y_GT, model, target_vars):
# X_targ = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
for i in range(FLAGS.num_steps):
- X_old = X
- X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005*scale_fac) * mask
+ X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005 * scale_fac) * mask
energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
@@ -798,19 +884,19 @@ def construct_steps(weights, X, Y_GT, model, target_vars):
X = X - FLAGS.step_lr * x_grad * scale_fac * mask
X = tf.clip_by_value(X, 0, 1)
- if i % n == (n-1):
+ if i % n == (n - 1):
X_mods.append(X)
print("Constructing step {}".format(i))
- target_vars['X_final'] = X
- target_vars['X_mods'] = X_mods
+ target_vars["X_final"] = X
+ target_vars["X_mods"] = X_mods
def nearest_neighbor(dataset, sess, target_vars, logdir):
- X = target_vars['X']
- Y_GT = target_vars['Y_GT']
- x_final = target_vars['X_final']
+ X = target_vars["X"]
+ Y_GT = target_vars["Y_GT"]
+ x_final = target_vars["X_final"]
noise = np.random.uniform(0, 1, size=[10, 32, 32, 3])
# label = np.random.randint(0, 10, size=[10])
@@ -819,24 +905,26 @@ def nearest_neighbor(dataset, sess, target_vars, logdir):
coarse = noise
for i in range(10):
- x_new = sess.run([x_final], {X:coarse, Y_GT:label})[0]
+ x_new = sess.run([x_final], {X: coarse, Y_GT: label})[0]
coarse = x_new
- x_new_dense = x_new.reshape(10, 1, 32*32*3)
- dataset_dense = dataset.reshape(1, 50000, 32*32*3)
+ x_new_dense = x_new.reshape(10, 1, 32 * 32 * 3)
+ dataset_dense = dataset.reshape(1, 50000, 32 * 32 * 3)
diff = np.square(x_new_dense - dataset_dense).sum(axis=2)
diff_idx = np.argsort(diff, axis=1)
- panel = np.zeros((32*10, 32*6, 3))
+ panel = np.zeros((32 * 10, 32 * 6, 3))
dataset_rescale = rescale_im(dataset)
x_new_rescale = rescale_im(x_new)
for i in range(10):
- panel[i*32:i*32+32, :32] = x_new_rescale[i]
+ panel[i * 32 : i * 32 + 32, :32] = x_new_rescale[i]
for j in range(5):
- panel[i*32:i*32+32, 32*j+32:32*j+64] = dataset_rescale[diff_idx[i, j]]
+ panel[i * 32 : i * 32 + 32, 32 * j + 32 : 32 * j + 64] = dataset_rescale[
+ diff_idx[i, j]
+ ]
imsave(osp.join(logdir, "nearest.png"), panel)
@@ -854,12 +942,24 @@ def main():
dataset = Svhn(train=True)
test_dataset = Svhn(train=False)
- if FLAGS.task == 'latent':
+ if FLAGS.task == "latent":
dataset = DSprites()
test_dataset = dataset
- dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)
- test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.data_workers,
+ shuffle=True,
+ drop_last=True,
+ )
+ test_dataloader = DataLoader(
+ test_dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.data_workers,
+ shuffle=True,
+ drop_last=True,
+ )
hidden_dim = 128
@@ -868,17 +968,17 @@ def main():
elif FLAGS.larger_model:
model = ResNet32Larger(num_filters=hidden_dim)
elif FLAGS.wider_model:
- if FLAGS.dataset == 'imagenet':
+ if FLAGS.dataset == "imagenet":
model = ResNet32Wider(num_filters=196, train=False)
else:
model = ResNet32Wider(num_filters=256, train=False)
else:
model = ResNet32(num_filters=hidden_dim)
- if FLAGS.task == 'latent':
+ if FLAGS.task == "latent":
model = DspritesNet()
- weights = model.construct_weights('context_{}'.format(0))
+ weights = model.construct_weights("context_{}".format(0))
total_parameters = 0
for variable in tf.trainable_variables():
@@ -890,90 +990,126 @@ def main():
total_parameters += variable_parameters
print("Model has a total of {} parameters".format(total_parameters))
- config = tf.ConfigProto()
+ tf.ConfigProto()
sess = tf.InteractiveSession()
- if FLAGS.task == 'latent':
- X = tf.placeholder(shape=(None, 64, 64), dtype = tf.float32)
+ if FLAGS.task == "latent":
+ X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
else:
- X = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
+ X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
if FLAGS.dataset == "cifar10":
- Y = tf.placeholder(shape=(None, 10), dtype = tf.float32)
- Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
+ Y = tf.placeholder(shape=(None, 10), dtype=tf.float32)
+ Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
elif FLAGS.dataset == "imagenet":
- Y = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
- Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
+ Y = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
+ Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
- target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}
+ target_vars = {"X": X, "Y": Y, "Y_GT": Y_GT}
- if FLAGS.task == 'label':
+ if FLAGS.task == "label":
construct_label(weights, X, Y, Y_GT, model, target_vars)
- elif FLAGS.task == 'labelfinetune':
- construct_finetune_label(weights, X, Y, Y_GT, model, target_vars, )
- elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
+ elif FLAGS.task == "labelfinetune":
+ construct_finetune_label(
+ weights,
+ X,
+ Y,
+ Y_GT,
+ model,
+ target_vars,
+ )
+ elif FLAGS.task == "energyeval" or FLAGS.task == "mixenergy":
construct_energy(weights, X, Y, Y_GT, model, target_vars)
- elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor':
+ elif (
+ FLAGS.task == "anticorrupt"
+ or FLAGS.task == "boxcorrupt"
+ or FLAGS.task == "crossclass"
+ or FLAGS.task == "cycleclass"
+ or FLAGS.task == "democlass"
+ or FLAGS.task == "nearestneighbor"
+ ):
construct_steps(weights, X, Y_GT, model, target_vars)
- elif FLAGS.task == 'latent':
+ elif FLAGS.task == "latent":
construct_latent(weights, X, Y_GT, model, target_vars)
sess.run(tf.global_variables_initializer())
saver = loader = tf.train.Saver(max_to_keep=10)
- savedir = osp.join('cachedir', FLAGS.exp)
+ savedir = osp.join("cachedir", FLAGS.exp)
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
if not osp.exists(logdir):
os.makedirs(logdir)
initialize()
if FLAGS.resume_iter != -1:
- model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
- resume_itr = FLAGS.resume_iter
+ model_file = osp.join(savedir, "model_{}".format(FLAGS.resume_iter))
+ FLAGS.resume_iter
- if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy":
+ if (
+ FLAGS.task == "label"
+ or FLAGS.task == "boxcorrupt"
+ or FLAGS.task == "labelfinetune"
+ or FLAGS.task == "energyeval"
+ or FLAGS.task == "crossclass"
+ or FLAGS.task == "mixenergy"
+ ):
optimistic_restore(sess, model_file)
# saver.restore(sess, model_file)
else:
# optimistic_restore(sess, model_file)
saver.restore(sess, model_file)
- if FLAGS.task == 'label':
+ if FLAGS.task == "label":
if FLAGS.labelgrid:
vals = []
if FLAGS.lnorm == -1:
for i in range(31):
- accuracies = label(dataloader, test_dataloader, target_vars, sess, l1val=i)
+ accuracies = label(
+ dataloader, test_dataloader, target_vars, sess, l1val=i
+ )
vals.append(accuracies)
elif FLAGS.lnorm == 2:
for i in range(0, 100, 5):
- accuracies = label(dataloader, test_dataloader, target_vars, sess, l2val=i)
+ accuracies = label(
+ dataloader, test_dataloader, target_vars, sess, l2val=i
+ )
vals.append(accuracies)
np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
else:
label(dataloader, test_dataloader, target_vars, sess)
- elif FLAGS.task == 'labelfinetune':
- labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val)
- elif FLAGS.task == 'energyeval':
+ elif FLAGS.task == "labelfinetune":
+ labelfinetune(
+ dataloader,
+ test_dataloader,
+ target_vars,
+ sess,
+ savedir,
+ saver,
+ l1val=FLAGS.lival,
+ l2val=FLAGS.l2val,
+ )
+ elif FLAGS.task == "energyeval":
energyeval(dataloader, test_dataloader, target_vars, sess)
- elif FLAGS.task == 'mixenergy':
+ elif FLAGS.task == "mixenergy":
energyevalmix(dataloader, test_dataloader, target_vars, sess)
- elif FLAGS.task == 'anticorrupt':
+ elif FLAGS.task == "anticorrupt":
anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
- elif FLAGS.task == 'boxcorrupt':
+ elif FLAGS.task == "boxcorrupt":
# boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
- boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess)
- elif FLAGS.task == 'crossclass':
+ boxcorrupt(
+ test_dataloader, dataloader, weights, model, target_vars, logdir, sess
+ )
+ elif FLAGS.task == "crossclass":
crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
- elif FLAGS.task == 'cycleclass':
+ elif FLAGS.task == "cycleclass":
cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
- elif FLAGS.task == 'democlass':
+ elif FLAGS.task == "democlass":
democlass(test_dataloader, weights, model, target_vars, logdir, sess)
- elif FLAGS.task == 'nearestneighbor':
+ elif FLAGS.task == "nearestneighbor":
# print(dir(dataset))
# print(type(dataset))
nearest_neighbor(dataset.data.train_data / 255, sess, target_vars, logdir)
- elif FLAGS.task == 'latent':
+ elif FLAGS.task == "latent":
latent(test_dataloader, weights, model, target_vars, sess)
diff --git a/EBMs/fid.py b/EBMs/fid.py
index 7aee938..f3b5d8f 100644
--- a/EBMs/fid.py
+++ b/EBMs/fid.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-''' Calculates the Frechet Inception Distance (FID) to evalulate GANs.
+"""Calculates the Frechet Inception Distance (FID) to evalulate GANs.
The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
@@ -14,28 +14,33 @@ the pool_3 layer of the inception net for generated samples and real world
samples respectivly.
See --help to see further details.
-'''
+"""
from __future__ import absolute_import, division, print_function
-import numpy as np
+
import os
-import gzip, pickle
-import tensorflow as tf
-from scipy.misc import imread
-from scipy import linalg
import pathlib
-import urllib
import tarfile
+import urllib
import warnings
-MODEL_DIR = '/tmp/imagenet'
-DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+import numpy as np
+import tensorflow as tf
+from scipy import linalg
+from scipy.misc import imread
+
+MODEL_DIR = "/tmp/imagenet"
+DATA_URL = (
+ "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
+)
pool3 = None
+
class InvalidFIDException(Exception):
pass
-#-------------------------------------------------------------------------------
+
+# -------------------------------------------------------------------------------
def get_fid_score(images, images_gt):
images = np.stack(images, 0)
images_gt = np.stack(images_gt, 0)
@@ -52,34 +57,38 @@ def get_fid_score(images, images_gt):
def create_inception_graph(pth):
"""Creates a graph from saved GraphDef file."""
# Creates graph from saved graph_def.pb.
- with tf.gfile.FastGFile( pth, 'rb') as f:
+ with tf.gfile.FastGFile(pth, "rb") as f:
graph_def = tf.GraphDef()
- graph_def.ParseFromString( f.read())
- _ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
-#-------------------------------------------------------------------------------
+ graph_def.ParseFromString(f.read())
+ _ = tf.import_graph_def(graph_def, name="FID_Inception_Net")
+
+
+# -------------------------------------------------------------------------------
# code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
- """Prepares inception net for batched usage and returns pool_3 layer. """
- layername = 'FID_Inception_Net/pool_3:0'
+ """Prepares inception net for batched usage and returns pool_3 layer."""
+ layername = "FID_Inception_Net/pool_3:0"
pool3 = sess.graph.get_tensor_by_name(layername)
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
- shape = [s.value for s in shape]
- new_shape = []
- for j, s in enumerate(shape):
- if s == 1 and j == 0:
- new_shape.append(None)
- else:
- new_shape.append(s)
- o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
+ shape = [s.value for s in shape]
+ new_shape = []
+ for j, s in enumerate(shape):
+ if s == 1 and j == 0:
+ new_shape.append(None)
+ else:
+ new_shape.append(s)
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
return pool3
-#-------------------------------------------------------------------------------
+
+
+# -------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=50, verbose=False):
@@ -100,23 +109,27 @@ def get_activations(images, sess, batch_size=50, verbose=False):
# inception_layer = _get_inception_layer(sess)
d0 = images.shape[0]
if batch_size > d0:
- print("warning: batch size is bigger than the data size. setting batch size to data size")
+ print(
+ "warning: batch size is bigger than the data size. setting batch size to data size"
+ )
batch_size = d0
- n_batches = d0//batch_size
- n_used_imgs = n_batches*batch_size
- pred_arr = np.empty((n_used_imgs,2048))
+ n_batches = d0 // batch_size
+ n_used_imgs = n_batches * batch_size
+ pred_arr = np.empty((n_used_imgs, 2048))
for i in range(n_batches):
if verbose:
- print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
- start = i*batch_size
+ print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
+ start = i * batch_size
end = start + batch_size
batch = images[start:end]
- pred = sess.run(pool3, {'ExpandDims:0': batch})
- pred_arr[start:end] = pred.reshape(batch_size,-1)
+ pred = sess.run(pool3, {"ExpandDims:0": batch})
+ pred_arr[start:end] = pred.reshape(batch_size, -1)
if verbose:
print(" done")
return pred_arr
-#-------------------------------------------------------------------------------
+
+
+# -------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
@@ -147,15 +160,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
- assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
- assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
+ assert (
+ mu1.shape == mu2.shape
+ ), "Training and test mean vectors have different lengths"
+ assert (
+ sigma1.shape == sigma2.shape
+ ), "Training and test covariances have different dimensions"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
- msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
+ msg = (
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
+ % eps
+ )
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@@ -170,7 +190,9 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
-#-------------------------------------------------------------------------------
+
+
+# -------------------------------------------------------------------------------
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
@@ -193,47 +215,52 @@ def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
-#-------------------------------------------------------------------------------
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
+
+
+# -------------------------------------------------------------------------------
# The following functions aren't needed for calculating the FID
# they're just here to make this module work as a stand-alone script
# for calculating FID scores
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
def check_or_download_inception(inception_path):
- ''' Checks if the path to the inception file is valid, or downloads
- the file if it is not present. '''
- INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+ """Checks if the path to the inception file is valid, or downloads
+ the file if it is not present."""
+ INCEPTION_URL = (
+ "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
+ )
if inception_path is None:
- inception_path = '/tmp'
+ inception_path = "/tmp"
inception_path = pathlib.Path(inception_path)
- model_file = inception_path / 'classify_image_graph_def.pb'
+ model_file = inception_path / "classify_image_graph_def.pb"
if not model_file.exists():
print("Downloading Inception model")
- from urllib import request
import tarfile
+ from urllib import request
+
fn, _ = request.urlretrieve(INCEPTION_URL)
- with tarfile.open(fn, mode='r') as f:
- f.extract('classify_image_graph_def.pb', str(model_file.parent))
+ with tarfile.open(fn, mode="r") as f:
+ f.extract("classify_image_graph_def.pb", str(model_file.parent))
return str(model_file)
def _handle_path(path, sess):
- if path.endswith('.npz'):
+ if path.endswith(".npz"):
f = np.load(path)
- m, s = f['mu'][:], f['sigma'][:]
+ m, s = f["mu"][:], f["sigma"][:]
f.close()
else:
path = pathlib.Path(path)
- files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
+ files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
m, s = calculate_activation_statistics(x, sess)
return m, s
def calculate_fid_given_paths(paths, inception_path):
- ''' Calculates the FID of two paths. '''
+ """Calculates the FID of two paths."""
inception_path = check_or_download_inception(inception_path)
for p in paths:
@@ -250,43 +277,48 @@ def calculate_fid_given_paths(paths, inception_path):
def _init_inception():
- global pool3
- if not os.path.exists(MODEL_DIR):
- os.makedirs(MODEL_DIR)
- filename = DATA_URL.split('/')[-1]
- filepath = os.path.join(MODEL_DIR, filename)
- if not os.path.exists(filepath):
- def _progress(count, block_size, total_size):
- sys.stdout.write('\r>> Downloading %s %.1f%%' % (
- filename, float(count * block_size) / float(total_size) * 100.0))
- sys.stdout.flush()
- filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
- print()
- statinfo = os.stat(filepath)
- print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
- tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
- with tf.gfile.FastGFile(os.path.join(
- MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(f.read())
- _ = tf.import_graph_def(graph_def, name='')
- # Works with an arbitrary minibatch size.
- with tf.Session() as sess:
- pool3 = sess.graph.get_tensor_by_name('pool_3:0')
- ops = pool3.graph.get_operations()
- for op_idx, op in enumerate(ops):
- for o in op.outputs:
- shape = o.get_shape()
- if shape._dims != []:
- shape = [s.value for s in shape]
- new_shape = []
- for j, s in enumerate(shape):
- if s == 1 and j == 0:
- new_shape.append(None)
- else:
- new_shape.append(s)
- o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
+ global pool3
+ if not os.path.exists(MODEL_DIR):
+ os.makedirs(MODEL_DIR)
+ filename = DATA_URL.split("/")[-1]
+ filepath = os.path.join(MODEL_DIR, filename)
+ if not os.path.exists(filepath):
+
+ def _progress(count, block_size, total_size):
+ sys.stdout.write(
+ "\r>> Downloading %s %.1f%%"
+ % (filename, float(count * block_size) / float(total_size) * 100.0)
+ )
+ sys.stdout.flush()
+
+ filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
+ print()
+ statinfo = os.stat(filepath)
+ print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
+ tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
+ with tf.gfile.FastGFile(
+ os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
+ ) as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ _ = tf.import_graph_def(graph_def, name="")
+ # Works with an arbitrary minibatch size.
+ with tf.Session() as sess:
+ pool3 = sess.graph.get_tensor_by_name("pool_3:0")
+ ops = pool3.graph.get_operations()
+ for op_idx, op in enumerate(ops):
+ for o in op.outputs:
+ shape = o.get_shape()
+ if shape._dims != []:
+ shape = [s.value for s in shape]
+ new_shape = []
+ for j, s in enumerate(shape):
+ if s == 1 and j == 0:
+ new_shape.append(None)
+ else:
+ new_shape.append(s)
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
if pool3 is None:
- _init_inception()
+ _init_inception()
diff --git a/EBMs/hmc.py b/EBMs/hmc.py
index 68821c9..fd5cf11 100644
--- a/EBMs/hmc.py
+++ b/EBMs/hmc.py
@@ -1,11 +1,11 @@
import tensorflow as tf
-import numpy as np
-
from tensorflow.python.platform import flags
-flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance raes')
+
+flags.DEFINE_bool("proposal_debug", False, "Print hmc acceptance raes")
FLAGS = flags.FLAGS
+
def kinetic_energy(velocity):
"""Kinetic energy of the current velocity (assuming a standard Gaussian)
(x dot x) / 2
@@ -21,6 +21,7 @@ def kinetic_energy(velocity):
"""
return 0.5 * tf.square(velocity)
+
def hamiltonian(position, velocity, energy_function):
"""Computes the Hamiltonian of the current position, velocity pair
@@ -44,13 +45,12 @@ def hamiltonian(position, velocity, energy_function):
"""
batch_size = tf.shape(velocity)[0]
kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
- return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1])
+ return tf.squeeze(energy_function(position)) + tf.reduce_sum(
+ kinetic_energy_flat, axis=[1]
+ )
-def leapfrog_step(x0,
- v0,
- neg_log_posterior,
- step_size,
- num_steps):
+
+def leapfrog_step(x0, v0, neg_log_posterior, step_size, num_steps):
# Start by updating the velocity a half-step
v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
@@ -83,10 +83,8 @@ def leapfrog_step(x0,
# return new proposal state
return x, v
-def hmc(initial_x,
- step_size,
- num_steps,
- neg_log_posterior):
+
+def hmc(initial_x, step_size, num_steps, neg_log_posterior):
"""Summary
Parameters
@@ -107,11 +105,13 @@ def hmc(initial_x,
"""
v0 = tf.random_normal(tf.shape(initial_x))
- x, v = leapfrog_step(initial_x,
- v0,
- step_size=step_size,
- num_steps=num_steps,
- neg_log_posterior=neg_log_posterior)
+ x, v = leapfrog_step(
+ initial_x,
+ v0,
+ step_size=step_size,
+ num_steps=num_steps,
+ neg_log_posterior=neg_log_posterior,
+ )
orig = hamiltonian(initial_x, v0, neg_log_posterior)
current = hamiltonian(x, v, neg_log_posterior)
@@ -119,10 +119,12 @@ def hmc(initial_x,
prob_accept = tf.exp(orig - current)
if FLAGS.proposal_debug:
- prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))])
+ prob_accept = tf.Print(
+ prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]
+ )
uniform = tf.random_uniform(tf.shape(prob_accept))
- keep_mask = (prob_accept > uniform)
+ keep_mask = prob_accept > uniform
# print(keep_mask.get_shape())
x_new = tf.where(keep_mask, x, initial_x)
diff --git a/EBMs/imagenet_demo.py b/EBMs/imagenet_demo.py
index 9d79395..9d6f1e4 100644
--- a/EBMs/imagenet_demo.py
+++ b/EBMs/imagenet_demo.py
@@ -1,24 +1,28 @@
-from models import ResNet128
-import numpy as np
import os.path as osp
-from tensorflow.python.platform import flags
-import tensorflow as tf
+
import imageio
-from utils import optimistic_restore
+import numpy as np
+import tensorflow as tf
+from models import ResNet128
+from tensorflow.python.platform import flags
-
-flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
-flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
-flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics')
-flags.DEFINE_integer('batch_size', 16, 'number of steps to run')
-flags.DEFINE_string('exp', 'default', 'name of experiments')
-flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
-flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model')
-flags.DEFINE_bool('cclass', True, 'conditional models')
-flags.DEFINE_bool('use_attention', False, 'using attention')
+flags.DEFINE_string(
+ "logdir", "cachedir", "location where log of experiments will be stored"
+)
+flags.DEFINE_integer("num_steps", 200, "num of steps for conditional imagenet sampling")
+flags.DEFINE_float("step_lr", 180.0, "step size for Langevin dynamics")
+flags.DEFINE_integer("batch_size", 16, "number of steps to run")
+flags.DEFINE_string("exp", "default", "name of experiments")
+flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from")
+flags.DEFINE_bool(
+ "spec_norm", True, "whether to use spectral normalization in weights in a model"
+)
+flags.DEFINE_bool("cclass", True, "conditional models")
+flags.DEFINE_bool("use_attention", False, "using attention")
FLAGS = flags.FLAGS
+
def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8)
@@ -32,12 +36,11 @@ if __name__ == "__main__":
weights = model.construct_weights("context_0")
x_mod = X_NOISE
- x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
- mean=0.0,
- stddev=0.005)
+ x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
- energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL,
- reuse=True, stop_at_grad=False, stop_batch=True)
+ energy_noise = energy_start = model.forward(
+ x_mod, weights, label=LABEL, reuse=True, stop_at_grad=False, stop_batch=True
+ )
x_grad = tf.gradients(energy_noise, [x_mod])[0]
energy_noise_old = energy_noise
@@ -54,20 +57,23 @@ if __name__ == "__main__":
saver = loader = tf.train.Saver()
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
- model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
+ model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
saver.restore(sess, model_file)
lx = np.random.permutation(1000)[:16]
ims = []
- # What to initialize sampling with.
+ # What to initialize sampling with.
x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
labels = np.eye(1000)[lx]
for i in range(FLAGS.num_steps):
- e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels})
- ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3)))
-
- imageio.mimwrite('sample.gif', ims)
-
+ e, x_mod = sess.run([energy_noise, x_output], {X_NOISE: x_mod, LABEL: labels})
+ ims.append(
+ rescale_im(x_mod)
+ .reshape((4, 4, 128, 128, 3))
+ .transpose((0, 2, 1, 3, 4))
+ .reshape((512, 512, 3))
+ )
+ imageio.mimwrite("sample.gif", ims)
diff --git a/EBMs/imagenet_preprocessing.py b/EBMs/imagenet_preprocessing.py
index cbfde0c..9648253 100644
--- a/EBMs/imagenet_preprocessing.py
+++ b/EBMs/imagenet_preprocessing.py
@@ -13,14 +13,11 @@
# limitations under the License.
# ==============================================================================
-"""Image pre-processing utilities.
-"""
+"""Image pre-processing utilities."""
import tensorflow as tf
+IMAGE_DEPTH = 3 # color images
-IMAGE_DEPTH = 3 # color images
-
-import tensorflow as tf
# _R_MEAN = 123.68
# _G_MEAN = 116.78
@@ -35,303 +32,318 @@ _RESIZE_MIN = 128
def _decode_crop_and_flip(image_buffer, bbox, num_channels):
- """Crops the given image to a random part of the image, and randomly flips.
+ """Crops the given image to a random part of the image, and randomly flips.
- We use the fused decode_and_crop op, which performs better than the two ops
- used separately in series, but note that this requires that the image be
- passed in as an un-decoded string Tensor.
+ We use the fused decode_and_crop op, which performs better than the two ops
+ used separately in series, but note that this requires that the image be
+ passed in as an un-decoded string Tensor.
- Args:
- image_buffer: scalar string Tensor representing the raw JPEG image buffer.
- bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
- where each coordinate is [0, 1) and the coordinates are arranged as
- [ymin, xmin, ymax, xmax].
- num_channels: Integer depth of the image buffer for decoding.
+ Args:
+ image_buffer: scalar string Tensor representing the raw JPEG image buffer.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ num_channels: Integer depth of the image buffer for decoding.
- Returns:
- 3-D tensor with cropped image.
+ Returns:
+ 3-D tensor with cropped image.
- """
- # A large fraction of image datasets contain a human-annotated bounding box
- # delineating the region of the image containing the object of interest. We
- # choose to create a new bounding box for the object which is a randomly
- # distorted version of the human-annotated bounding box that obeys an
- # allowed range of aspect ratios, sizes and overlap with the human-annotated
- # bounding box. If no box is supplied, then we assume the bounding box is
- # the entire image.
- sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
- tf.image.extract_jpeg_shape(image_buffer),
- bounding_boxes=bbox,
- min_object_covered=0.1,
- aspect_ratio_range=[0.75, 1.33],
- area_range=[0.05, 1.0],
- max_attempts=100,
- use_image_if_no_bounding_boxes=True)
- bbox_begin, bbox_size, _ = sample_distorted_bounding_box
+ """
+ # A large fraction of image datasets contain a human-annotated bounding box
+ # delineating the region of the image containing the object of interest. We
+ # choose to create a new bounding box for the object which is a randomly
+ # distorted version of the human-annotated bounding box that obeys an
+ # allowed range of aspect ratios, sizes and overlap with the human-annotated
+ # bounding box. If no box is supplied, then we assume the bounding box is
+ # the entire image.
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+ tf.image.extract_jpeg_shape(image_buffer),
+ bounding_boxes=bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=[0.75, 1.33],
+ area_range=[0.05, 1.0],
+ max_attempts=100,
+ use_image_if_no_bounding_boxes=True,
+ )
+ bbox_begin, bbox_size, _ = sample_distorted_bounding_box
- # Reassemble the bounding box in the format the crop op requires.
- offset_y, offset_x, _ = tf.unstack(bbox_begin)
- target_height, target_width, _ = tf.unstack(bbox_size)
- crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+ # Reassemble the bounding box in the format the crop op requires.
+ offset_y, offset_x, _ = tf.unstack(bbox_begin)
+ target_height, target_width, _ = tf.unstack(bbox_size)
+ crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
- # Use the fused decode and crop op here, which is faster than each in series.
- cropped = tf.image.decode_and_crop_jpeg(
- image_buffer, crop_window, channels=num_channels)
+ # Use the fused decode and crop op here, which is faster than each in
+ # series.
+ cropped = tf.image.decode_and_crop_jpeg(
+ image_buffer, crop_window, channels=num_channels
+ )
- # Flip to add a little more random distortion in.
- cropped = tf.image.random_flip_left_right(cropped)
- return cropped
+ # Flip to add a little more random distortion in.
+ cropped = tf.image.random_flip_left_right(cropped)
+ return cropped
def _central_crop(image, crop_height, crop_width):
- """Performs central crops of the given image list.
+ """Performs central crops of the given image list.
- Args:
- image: a 3-D image tensor
- crop_height: the height of the image following the crop.
- crop_width: the width of the image following the crop.
+ Args:
+ image: a 3-D image tensor
+ crop_height: the height of the image following the crop.
+ crop_width: the width of the image following the crop.
- Returns:
- 3-D tensor with cropped image.
- """
- shape = tf.shape(input=image)
- height, width = shape[0], shape[1]
+ Returns:
+ 3-D tensor with cropped image.
+ """
+ shape = tf.shape(input=image)
+ height, width = shape[0], shape[1]
- amount_to_be_cropped_h = (height - crop_height)
- crop_top = amount_to_be_cropped_h // 2
- amount_to_be_cropped_w = (width - crop_width)
- crop_left = amount_to_be_cropped_w // 2
- return tf.slice(
- image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
+ amount_to_be_cropped_h = height - crop_height
+ crop_top = amount_to_be_cropped_h // 2
+ amount_to_be_cropped_w = width - crop_width
+ crop_left = amount_to_be_cropped_w // 2
+ return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
def _mean_image_subtraction(image, means, num_channels):
- """Subtracts the given means from each image channel.
+ """Subtracts the given means from each image channel.
- For example:
- means = [123.68, 116.779, 103.939]
- image = _mean_image_subtraction(image, means)
+ For example:
+ means = [123.68, 116.779, 103.939]
+ image = _mean_image_subtraction(image, means)
- Note that the rank of `image` must be known.
+ Note that the rank of `image` must be known.
- Args:
- image: a tensor of size [height, width, C].
- means: a C-vector of values to subtract from each channel.
- num_channels: number of color channels in the image that will be distorted.
+ Args:
+ image: a tensor of size [height, width, C].
+ means: a C-vector of values to subtract from each channel.
+ num_channels: number of color channels in the image that will be distorted.
- Returns:
- the centered image.
+ Returns:
+ the centered image.
- Raises:
- ValueError: If the rank of `image` is unknown, if `image` has a rank other
- than three or if the number of channels in `image` doesn't match the
- number of values in `means`.
- """
- if image.get_shape().ndims != 3:
- raise ValueError('Input must be of size [height, width, C>0]')
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `means`.
+ """
+ if image.get_shape().ndims != 3:
+ raise ValueError("Input must be of size [height, width, C>0]")
- if len(means) != num_channels:
- raise ValueError('len(means) must match the number of channels')
+ if len(means) != num_channels:
+ raise ValueError("len(means) must match the number of channels")
- # We have a 1-D tensor of means; convert to 3-D.
- means = tf.expand_dims(tf.expand_dims(means, 0), 0)
+ # We have a 1-D tensor of means; convert to 3-D.
+ means = tf.expand_dims(tf.expand_dims(means, 0), 0)
- return image - means
+ return image - means
def _smallest_size_at_least(height, width, resize_min):
- """Computes new shape with the smallest side equal to `smallest_side`.
+ """Computes new shape with the smallest side equal to `smallest_side`.
- Computes new shape with the smallest side equal to `smallest_side` while
- preserving the original aspect ratio.
+ Computes new shape with the smallest side equal to `smallest_side` while
+ preserving the original aspect ratio.
- Args:
- height: an int32 scalar tensor indicating the current height.
- width: an int32 scalar tensor indicating the current width.
- resize_min: A python integer or scalar `Tensor` indicating the size of
- the smallest side after resize.
+ Args:
+ height: an int32 scalar tensor indicating the current height.
+ width: an int32 scalar tensor indicating the current width.
+ resize_min: A python integer or scalar `Tensor` indicating the size of
+ the smallest side after resize.
- Returns:
- new_height: an int32 scalar tensor indicating the new height.
- new_width: an int32 scalar tensor indicating the new width.
- """
- resize_min = tf.cast(resize_min, tf.float32)
+ Returns:
+ new_height: an int32 scalar tensor indicating the new height.
+ new_width: an int32 scalar tensor indicating the new width.
+ """
+ resize_min = tf.cast(resize_min, tf.float32)
- # Convert to floats to make subsequent calculations go smoothly.
- height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
+ # Convert to floats to make subsequent calculations go smoothly.
+ height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
- smaller_dim = tf.minimum(height, width)
- scale_ratio = resize_min / smaller_dim
+ smaller_dim = tf.minimum(height, width)
+ scale_ratio = resize_min / smaller_dim
- # Convert back to ints to make heights and widths that TF ops will accept.
- new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
- new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
+ # Convert back to ints to make heights and widths that TF ops will accept.
+ new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
+ new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
- return new_height, new_width
+ return new_height, new_width
def _aspect_preserving_resize(image, resize_min):
- """Resize images preserving the original aspect ratio.
+ """Resize images preserving the original aspect ratio.
- Args:
- image: A 3-D image `Tensor`.
- resize_min: A python integer or scalar `Tensor` indicating the size of
- the smallest side after resize.
+ Args:
+ image: A 3-D image `Tensor`.
+ resize_min: A python integer or scalar `Tensor` indicating the size of
+ the smallest side after resize.
- Returns:
- resized_image: A 3-D tensor containing the resized image.
- """
- shape = tf.shape(input=image)
- height, width = shape[0], shape[1]
+ Returns:
+ resized_image: A 3-D tensor containing the resized image.
+ """
+ shape = tf.shape(input=image)
+ height, width = shape[0], shape[1]
- new_height, new_width = _smallest_size_at_least(height, width, resize_min)
+ new_height, new_width = _smallest_size_at_least(height, width, resize_min)
- return _resize_image(image, new_height, new_width)
+ return _resize_image(image, new_height, new_width)
def _resize_image(image, height, width):
- """Simple wrapper around tf.resize_images.
+ """Simple wrapper around tf.resize_images.
- This is primarily to make sure we use the same `ResizeMethod` and other
- details each time.
+ This is primarily to make sure we use the same `ResizeMethod` and other
+ details each time.
- Args:
- image: A 3-D image `Tensor`.
- height: The target height for the resized image.
- width: The target width for the resized image.
+ Args:
+ image: A 3-D image `Tensor`.
+ height: The target height for the resized image.
+ width: The target width for the resized image.
- Returns:
- resized_image: A 3-D tensor containing the resized image. The first two
- dimensions have the shape [height, width].
- """
- return tf.image.resize_images(
- image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
- align_corners=False)
+ Returns:
+ resized_image: A 3-D tensor containing the resized image. The first two
+ dimensions have the shape [height, width].
+ """
+ return tf.image.resize_images(
+ image,
+ [height, width],
+ method=tf.image.ResizeMethod.BILINEAR,
+ align_corners=False,
+ )
-def preprocess_image(image_buffer, bbox, output_height, output_width,
- num_channels, is_training=False):
- """Preprocesses the given image.
+def preprocess_image(
+ image_buffer, bbox, output_height, output_width, num_channels, is_training=False
+):
+ """Preprocesses the given image.
- Preprocessing includes decoding, cropping, and resizing for both training
- and eval images. Training preprocessing, however, introduces some random
- distortion of the image to improve accuracy.
+ Preprocessing includes decoding, cropping, and resizing for both training
+ and eval images. Training preprocessing, however, introduces some random
+ distortion of the image to improve accuracy.
- Args:
- image_buffer: scalar string Tensor representing the raw JPEG image buffer.
- bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
- where each coordinate is [0, 1) and the coordinates are arranged as
- [ymin, xmin, ymax, xmax].
- output_height: The height of the image after preprocessing.
- output_width: The width of the image after preprocessing.
- num_channels: Integer depth of the image buffer for decoding.
- is_training: `True` if we're preprocessing the image for training and
- `False` otherwise.
+ Args:
+ image_buffer: scalar string Tensor representing the raw JPEG image buffer.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ output_height: The height of the image after preprocessing.
+ output_width: The width of the image after preprocessing.
+ num_channels: Integer depth of the image buffer for decoding.
+ is_training: `True` if we're preprocessing the image for training and
+ `False` otherwise.
- Returns:
- A preprocessed image.
- """
- if is_training:
- # For training, we want to randomize some of the distortions.
- image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
- image = _resize_image(image, output_height, output_width)
- else:
- # For validation, we want to decode, resize, then just crop the middle.
- image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
- image = _aspect_preserving_resize(image, _RESIZE_MIN)
- print(image)
- image = _central_crop(image, output_height, output_width)
+ Returns:
+ A preprocessed image.
+ """
+ if is_training:
+ # For training, we want to randomize some of the distortions.
+ image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
+ image = _resize_image(image, output_height, output_width)
+ else:
+ # For validation, we want to decode, resize, then just crop the middle.
+ image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
+ image = _aspect_preserving_resize(image, _RESIZE_MIN)
+ print(image)
+ image = _central_crop(image, output_height, output_width)
- image.set_shape([output_height, output_width, num_channels])
+ image.set_shape([output_height, output_width, num_channels])
- return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
+ return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
def parse_example_proto(example_serialized):
- """Parses an Example proto containing a training example of an image.
+ """Parses an Example proto containing a training example of an image.
- The output of the build_image_data.py image preprocessing script is a dataset
- containing serialized Example protocol buffers. Each Example proto contains
- the following fields:
+ The output of the build_image_data.py image preprocessing script is a dataset
+ containing serialized Example protocol buffers. Each Example proto contains
+ the following fields:
- image/height: 462
- image/width: 581
- image/colorspace: 'RGB'
- image/channels: 3
- image/class/label: 615
- image/class/synset: 'n03623198'
- image/class/text: 'knee pad'
- image/object/bbox/xmin: 0.1
- image/object/bbox/xmax: 0.9
- image/object/bbox/ymin: 0.2
- image/object/bbox/ymax: 0.6
- image/object/bbox/label: 615
- image/format: 'JPEG'
- image/filename: 'ILSVRC2012_val_00041207.JPEG'
- image/encoded:
+ image/height: 462
+ image/width: 581
+ image/colorspace: 'RGB'
+ image/channels: 3
+ image/class/label: 615
+ image/class/synset: 'n03623198'
+ image/class/text: 'knee pad'
+ image/object/bbox/xmin: 0.1
+ image/object/bbox/xmax: 0.9
+ image/object/bbox/ymin: 0.2
+ image/object/bbox/ymax: 0.6
+ image/object/bbox/label: 615
+ image/format: 'JPEG'
+ image/filename: 'ILSVRC2012_val_00041207.JPEG'
+ image/encoded:
- Args:
- example_serialized: scalar Tensor tf.string containing a serialized
- Example protocol buffer.
+ Args:
+ example_serialized: scalar Tensor tf.string containing a serialized
+ Example protocol buffer.
- Returns:
- image_buffer: Tensor tf.string containing the contents of a JPEG file.
- label: Tensor tf.int32 containing the label.
- bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
- where each coordinate is [0, 1) and the coordinates are arranged as
- [ymin, xmin, ymax, xmax].
- text: Tensor tf.string containing the human-readable label.
- """
- # Dense features in Example proto.
- feature_map = {
- 'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
- default_value=''),
- 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
- default_value=-1),
- 'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
- default_value=''),
- }
- sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
- # Sparse features in Example proto.
- feature_map.update(
- {k: sparse_float32 for k in ['image/object/bbox/xmin',
- 'image/object/bbox/ymin',
- 'image/object/bbox/xmax',
- 'image/object/bbox/ymax']})
+ Returns:
+ image_buffer: Tensor tf.string containing the contents of a JPEG file.
+ label: Tensor tf.int32 containing the label.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ text: Tensor tf.string containing the human-readable label.
+ """
+ # Dense features in Example proto.
+ feature_map = {
+ "image/encoded": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
+ "image/class/label": tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
+ "image/class/text": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
+ }
+ sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
+ # Sparse features in Example proto.
+ feature_map.update(
+ {
+ k: sparse_float32
+ for k in [
+ "image/object/bbox/xmin",
+ "image/object/bbox/ymin",
+ "image/object/bbox/xmax",
+ "image/object/bbox/ymax",
+ ]
+ }
+ )
- features = tf.parse_single_example(example_serialized, feature_map)
- label = tf.cast(features['image/class/label'], dtype=tf.int32)
+ features = tf.parse_single_example(example_serialized, feature_map)
+ label = tf.cast(features["image/class/label"], dtype=tf.int32)
- xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
- ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
- xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
- ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
+ xmin = tf.expand_dims(features["image/object/bbox/xmin"].values, 0)
+ ymin = tf.expand_dims(features["image/object/bbox/ymin"].values, 0)
+ xmax = tf.expand_dims(features["image/object/bbox/xmax"].values, 0)
+ ymax = tf.expand_dims(features["image/object/bbox/ymax"].values, 0)
- # Note that we impose an ordering of (y, x) just to make life difficult.
- bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
+ # Note that we impose an ordering of (y, x) just to make life difficult.
+ bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
- # Force the variable number of bounding boxes into the shape
- # [1, num_boxes, coords].
- bbox = tf.expand_dims(bbox, 0)
- bbox = tf.transpose(bbox, [0, 2, 1])
+ # Force the variable number of bounding boxes into the shape
+ # [1, num_boxes, coords].
+ bbox = tf.expand_dims(bbox, 0)
+ bbox = tf.transpose(bbox, [0, 2, 1])
- return features['image/encoded'], label, bbox, features['image/class/text']
+ return features["image/encoded"], label, bbox, features["image/class/text"]
class ImagenetPreprocessor:
- def __init__(self, image_size, dtype, train):
- self.image_size = image_size
- self.dtype = dtype
- self.train = train
+ def __init__(self, image_size, dtype, train):
+ self.image_size = image_size
+ self.dtype = dtype
+ self.train = train
- def preprocess(self, image_buffer, bbox):
- # pylint: disable=g-import-not-at-top
- image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train)
- return tf.cast(image, self.dtype)
-
- def parse_and_preprocess(self, value):
- image_buffer, label_index, bbox, _ = parse_example_proto(value)
- image = self.preprocess(image_buffer, bbox)
- image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
- return label_index, image
+ def preprocess(self, image_buffer, bbox):
+ # pylint: disable=g-import-not-at-top
+ image = preprocess_image(
+ image_buffer,
+ bbox,
+ self.image_size,
+ self.image_size,
+ IMAGE_DEPTH,
+ is_training=self.train,
+ )
+ return tf.cast(image, self.dtype)
+ def parse_and_preprocess(self, value):
+ image_buffer, label_index, bbox, _ = parse_example_proto(value)
+ image = self.preprocess(image_buffer, bbox)
+ image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
+ return label_index, image
diff --git a/EBMs/inception.py b/EBMs/inception.py
index 6c76f3b..5bb14aa 100644
--- a/EBMs/inception.py
+++ b/EBMs/inception.py
@@ -1,105 +1,112 @@
-# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+# Code derived from
+# tensorflow/tensorflow/models/image/imagenet/classify_image.py
+from __future__ import absolute_import, division, print_function
+import math
import os.path
import sys
import tarfile
-import numpy as np
-from six.moves import urllib
-import tensorflow as tf
-import glob
-import scipy.misc
-import math
-import sys
-
import horovod.tensorflow as hvd
+import numpy as np
+import tensorflow as tf
+from six.moves import urllib
-MODEL_DIR = '/tmp/imagenet'
-DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+MODEL_DIR = "/tmp/imagenet"
+DATA_URL = (
+ "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
+)
softmax = None
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config)
-# Call this function with list of images. Each of elements should be a
+
+# Call this function with list of images. Each of elements should be a
# numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10):
- # For convenience
- if len(images[0].shape) != 3:
- return 0, 0
+ # For convenience
+ if len(images[0].shape) != 3:
+ return 0, 0
+
+ # Bypassing all the assertions so that we don't end prematuraly'
+ # assert(type(images) == list)
+ # assert(type(images[0]) == np.ndarray)
+ # assert(len(images[0].shape) == 3)
+ # assert(np.max(images[0]) > 10)
+ # assert(np.min(images[0]) >= 0.0)
+ inps = []
+ for img in images:
+ img = img.astype(np.float32)
+ inps.append(np.expand_dims(img, 0))
+ bs = 1
+ preds = []
+ n_batches = int(math.ceil(float(len(inps)) / float(bs)))
+ for i in range(n_batches):
+ sys.stdout.write(".")
+ sys.stdout.flush()
+ inp = inps[(i * bs) : min((i + 1) * bs, len(inps))]
+ inp = np.concatenate(inp, 0)
+ pred = sess.run(softmax, {"ExpandDims:0": inp})
+ preds.append(pred)
+ preds = np.concatenate(preds, 0)
+ scores = []
+ for i in range(splits):
+ part = preds[
+ (i * preds.shape[0] // splits) : ((i + 1) * preds.shape[0] // splits), :
+ ]
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
+ kl = np.mean(np.sum(kl, 1))
+ scores.append(np.exp(kl))
+ return np.mean(scores), np.std(scores)
- # Bypassing all the assertions so that we don't end prematuraly'
- # assert(type(images) == list)
- # assert(type(images[0]) == np.ndarray)
- # assert(len(images[0].shape) == 3)
- # assert(np.max(images[0]) > 10)
- # assert(np.min(images[0]) >= 0.0)
- inps = []
- for img in images:
- img = img.astype(np.float32)
- inps.append(np.expand_dims(img, 0))
- bs = 1
- preds = []
- n_batches = int(math.ceil(float(len(inps)) / float(bs)))
- for i in range(n_batches):
- sys.stdout.write(".")
- sys.stdout.flush()
- inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
- inp = np.concatenate(inp, 0)
- pred = sess.run(softmax, {'ExpandDims:0': inp})
- preds.append(pred)
- preds = np.concatenate(preds, 0)
- scores = []
- for i in range(splits):
- part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
- kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
- kl = np.mean(np.sum(kl, 1))
- scores.append(np.exp(kl))
- return np.mean(scores), np.std(scores)
# This function is called automatically.
def _init_inception():
- global softmax
- if not os.path.exists(MODEL_DIR):
- os.makedirs(MODEL_DIR)
- filename = DATA_URL.split('/')[-1]
- filepath = os.path.join(MODEL_DIR, filename)
- if not os.path.exists(filepath):
- def _progress(count, block_size, total_size):
- sys.stdout.write('\r>> Downloading %s %.1f%%' % (
- filename, float(count * block_size) / float(total_size) * 100.0))
- sys.stdout.flush()
- filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
- print()
- statinfo = os.stat(filepath)
- print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
- tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
- with tf.gfile.FastGFile(os.path.join(
- MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(f.read())
- _ = tf.import_graph_def(graph_def, name='')
- # Works with an arbitrary minibatch size.
- pool3 = sess.graph.get_tensor_by_name('pool_3:0')
- ops = pool3.graph.get_operations()
- for op_idx, op in enumerate(ops):
- for o in op.outputs:
- shape = o.get_shape()
- shape = [s.value for s in shape]
- new_shape = []
- for j, s in enumerate(shape):
- if s == 1 and j == 0:
- new_shape.append(None)
- else:
- new_shape.append(s)
- o.set_shape(tf.TensorShape(new_shape))
- w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
- logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
- softmax = tf.nn.softmax(logits)
+ global softmax
+ if not os.path.exists(MODEL_DIR):
+ os.makedirs(MODEL_DIR)
+ filename = DATA_URL.split("/")[-1]
+ filepath = os.path.join(MODEL_DIR, filename)
+ if not os.path.exists(filepath):
+
+ def _progress(count, block_size, total_size):
+ sys.stdout.write(
+ "\r>> Downloading %s %.1f%%"
+ % (filename, float(count * block_size) / float(total_size) * 100.0)
+ )
+ sys.stdout.flush()
+
+ filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
+ print()
+ statinfo = os.stat(filepath)
+ print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
+ tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
+ with tf.gfile.FastGFile(
+ os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
+ ) as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ _ = tf.import_graph_def(graph_def, name="")
+ # Works with an arbitrary minibatch size.
+ pool3 = sess.graph.get_tensor_by_name("pool_3:0")
+ ops = pool3.graph.get_operations()
+ for op_idx, op in enumerate(ops):
+ for o in op.outputs:
+ shape = o.get_shape()
+ shape = [s.value for s in shape]
+ new_shape = []
+ for j, s in enumerate(shape):
+ if s == 1 and j == 0:
+ new_shape.append(None)
+ else:
+ new_shape.append(s)
+ o.set_shape(tf.TensorShape(new_shape))
+ w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
+ logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
+ softmax = tf.nn.softmax(logits)
+
if softmax is None:
- _init_inception()
+ _init_inception()
diff --git a/EBMs/models.py b/EBMs/models.py
index 7099528..70c2d43 100644
--- a/EBMs/models.py
+++ b/EBMs/models.py
@@ -1,10 +1,21 @@
+import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
-import numpy as np
-from utils import conv_block, get_weight, attention, conv_cond_concat, init_conv_weight, init_attention_weight, init_res_weight, smart_res_block, smart_res_block_optim, init_convt_weight
-from utils import init_fc_weight, smart_conv_block, smart_fc_block, smart_atten_block, groupsort, smart_convt_block, swish
+from utils import (
+ conv_cond_concat,
+ groupsort,
+ init_attention_weight,
+ init_conv_weight,
+ init_fc_weight,
+ init_res_weight,
+ smart_atten_block,
+ smart_conv_block,
+ smart_fc_block,
+ smart_res_block,
+ swish,
+)
-flags.DEFINE_bool('swish_act', False, 'use the swish activation for dsprites')
+flags.DEFINE_bool("swish_act", False, "use the swish activation for dsprites")
FLAGS = flags.FLAGS
@@ -21,22 +32,37 @@ class MnistNet(object):
else:
self.label_size = 0
- def construct_weights(self, scope=''):
+ def construct_weights(self, scope=""):
weights = {}
dtype = tf.float32
- conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
- fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
+ conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
+ fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
classes = 1
with tf.variable_scope(scope):
- init_conv_weight(weights, 'c1_pre', 3, 1, 64)
- init_conv_weight(weights, 'c1', 4, 64, self.dim_hidden, classes=classes)
- init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_fc_weight(weights, 'fc_dense', 4*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
- init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
+ init_conv_weight(weights, "c1_pre", 3, 1, 64)
+ init_conv_weight(weights, "c1", 4, 64, self.dim_hidden, classes=classes)
+ init_conv_weight(
+ weights, "c2", 4, self.dim_hidden, 2 * self.dim_hidden, classes=classes
+ )
+ init_conv_weight(
+ weights,
+ "c3",
+ 4,
+ 2 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_fc_weight(
+ weights,
+ "fc_dense",
+ 4 * 4 * 4 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ spec_norm=True,
+ )
+ init_fc_weight(weights, "fc5", 2 * self.dim_hidden, 1, spec_norm=False)
if FLAGS.cclass:
self.label_size = 10
@@ -44,8 +70,10 @@ class MnistNet(object):
self.label_size = 0
return weights
- def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, **kwargs):
- channels = self.channels
+ def forward(
+ self, inp, weights, reuse=False, scope="", stop_grad=False, label=None, **kwargs
+ ):
+ self.channels
weights = weights.copy()
inp = tf.reshape(inp, (tf.shape(inp)[0], 28, 28, 1))
@@ -56,7 +84,7 @@ class MnistNet(object):
if stop_grad:
for k, v in weights.items():
- if type(v) == dict:
+ if isinstance(v, dict):
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
@@ -65,24 +93,67 @@ class MnistNet(object):
weights[k] = tf.stop_gradient(v)
if FLAGS.cclass:
- label_d = tf.reshape(label, shape=(tf.shape(label)[0], 1, 1, self.label_size))
+ label_d = tf.reshape(
+ label, shape=(tf.shape(label)[0], 1, 1, self.label_size)
+ )
inp = conv_cond_concat(inp, label_d)
- h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
- h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
- h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
- h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=False, extra_bias=False, activation=act)
+ h1 = smart_conv_block(
+ inp, weights, reuse, "c1_pre", use_stride=False, activation=act
+ )
+ h2 = smart_conv_block(
+ h1,
+ weights,
+ reuse,
+ "c1",
+ use_stride=True,
+ downsample=True,
+ label=label,
+ extra_bias=False,
+ activation=act,
+ )
+ h3 = smart_conv_block(
+ h2,
+ weights,
+ reuse,
+ "c2",
+ use_stride=True,
+ downsample=True,
+ label=label,
+ extra_bias=False,
+ activation=act,
+ )
+ h4 = smart_conv_block(
+ h3,
+ weights,
+ reuse,
+ "c3",
+ use_stride=True,
+ downsample=True,
+ label=label,
+ use_scale=False,
+ extra_bias=False,
+ activation=act,
+ )
h5 = tf.reshape(h4, [-1, np.prod([int(dim) for dim in h4.get_shape()[1:]])])
- h6 = act(smart_fc_block(h5, weights, reuse, 'fc_dense'))
- hidden6 = smart_fc_block(h6, weights, reuse, 'fc5')
+ h6 = act(smart_fc_block(h5, weights, reuse, "fc_dense"))
+ hidden6 = smart_fc_block(h6, weights, reuse, "fc5")
return hidden6
class DspritesNet(object):
- def __init__(self, num_channels=1, num_filters=64, cond_size=False, cond_shape=False, cond_pos=False,
- cond_rot=False, label_size=1):
+ def __init__(
+ self,
+ num_channels=1,
+ num_filters=64,
+ cond_size=False,
+ cond_shape=False,
+ cond_pos=False,
+ cond_rot=False,
+ label_size=1,
+ ):
self.channels = num_channels
self.dim_hidden = num_filters
@@ -104,7 +175,7 @@ class DspritesNet(object):
if FLAGS.drot_only:
self.label_size = 2
- except:
+ except BaseException:
pass
if cond_size:
@@ -123,28 +194,60 @@ class DspritesNet(object):
self.cond_shape = cond_shape
self.cond_pos = cond_pos
- def construct_weights(self, scope=''):
+ def construct_weights(self, scope=""):
weights = {}
dtype = tf.float32
- conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
- fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
- k = 5
+ conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
+ fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
classes = self.label_size
with tf.variable_scope(scope):
- init_conv_weight(weights, 'c1_pre', 3, 1, 32)
- init_conv_weight(weights, 'c1', 4, 32, self.dim_hidden, classes=classes)
- init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_conv_weight(weights, 'c4', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_fc_weight(weights, 'fc_dense', 2*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
- init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
+ init_conv_weight(weights, "c1_pre", 3, 1, 32)
+ init_conv_weight(weights, "c1", 4, 32, self.dim_hidden, classes=classes)
+ init_conv_weight(
+ weights, "c2", 4, self.dim_hidden, 2 * self.dim_hidden, classes=classes
+ )
+ init_conv_weight(
+ weights,
+ "c3",
+ 4,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_conv_weight(
+ weights,
+ "c4",
+ 4,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_fc_weight(
+ weights,
+ "fc_dense",
+ 2 * 4 * 4 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ spec_norm=True,
+ )
+ init_fc_weight(weights, "fc5", 2 * self.dim_hidden, 1, spec_norm=False)
return weights
- def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False, return_logit=False):
- channels = self.channels
+ def forward(
+ self,
+ inp,
+ weights,
+ reuse=False,
+ scope="",
+ stop_grad=False,
+ label=None,
+ stop_at_grad=False,
+ stop_batch=False,
+ return_logit=False,
+ ):
+ self.channels
batch_size = tf.shape(inp)[0]
inp = tf.reshape(inp, (batch_size, 64, 64, 1))
@@ -161,7 +264,7 @@ class DspritesNet(object):
if stop_grad:
for k, v in weights.items():
- if type(v) == dict:
+ if isinstance(v, dict):
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
@@ -169,15 +272,58 @@ class DspritesNet(object):
else:
weights[k] = tf.stop_gradient(v)
- h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
- h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
- h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
- h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=True, extra_bias=True, activation=act)
- h5 = smart_conv_block(h4, weights, reuse, 'c4', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
+ h1 = smart_conv_block(
+ inp, weights, reuse, "c1_pre", use_stride=False, activation=act
+ )
+ h2 = smart_conv_block(
+ h1,
+ weights,
+ reuse,
+ "c1",
+ use_stride=True,
+ downsample=True,
+ label=label,
+ extra_bias=True,
+ activation=act,
+ )
+ h3 = smart_conv_block(
+ h2,
+ weights,
+ reuse,
+ "c2",
+ use_stride=True,
+ downsample=True,
+ label=label,
+ extra_bias=True,
+ activation=act,
+ )
+ h4 = smart_conv_block(
+ h3,
+ weights,
+ reuse,
+ "c3",
+ use_stride=True,
+ downsample=True,
+ label=label,
+ use_scale=True,
+ extra_bias=True,
+ activation=act,
+ )
+ h5 = smart_conv_block(
+ h4,
+ weights,
+ reuse,
+ "c4",
+ use_stride=True,
+ downsample=True,
+ label=label,
+ extra_bias=True,
+ activation=act,
+ )
hidden6 = tf.reshape(h5, (tf.shape(h5)[0], -1))
- hidden7 = act(smart_fc_block(hidden6, weights, reuse, 'fc_dense'))
- energy = smart_fc_block(hidden7, weights, reuse, 'fc5')
+ hidden7 = act(smart_fc_block(hidden6, weights, reuse, "fc_dense"))
+ energy = smart_fc_block(hidden7, weights, reuse, "fc5")
if return_logit:
return hidden7
@@ -185,7 +331,6 @@ class DspritesNet(object):
return energy
-
class ResNet32(object):
def __init__(self, num_channels=3, num_filters=128):
@@ -193,9 +338,9 @@ class ResNet32(object):
self.dim_hidden = num_filters
self.groupsort = groupsort()
- def construct_weights(self, scope=''):
+ def construct_weights(self, scope=""):
weights = {}
- dtype = tf.float32
+ tf.float32
if FLAGS.cclass:
classes = 10
@@ -204,23 +349,78 @@ class ResNet32(object):
with tf.variable_scope(scope):
# First block
- init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
- init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_2', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_3', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
- init_fc_weight(weights, 'fc5', 2*self.dim_hidden , 1, spec_norm=False)
+ init_conv_weight(weights, "c1_pre", 3, self.channels, self.dim_hidden)
+ init_res_weight(
+ weights,
+ "res_optim",
+ 3,
+ self.dim_hidden,
+ self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights,
+ "res_2",
+ 3,
+ self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_3",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_4",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_5",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_fc_weight(
+ weights, "fc_dense", 4 * 4 * 2 * self.dim_hidden, 4 * self.dim_hidden
+ )
+ init_fc_weight(weights, "fc5", 2 * self.dim_hidden, 1, spec_norm=False)
- init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True)
+ init_attention_weight(
+ weights,
+ "atten",
+ 2 * self.dim_hidden,
+ self.dim_hidden / 2,
+ trainable_gamma=True,
+ )
return weights
- def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
+ def forward(
+ self,
+ inp,
+ weights,
+ reuse=False,
+ scope="",
+ stop_grad=False,
+ label=None,
+ stop_at_grad=False,
+ stop_batch=False,
+ ):
weights = weights.copy()
- batch = tf.shape(inp)[0]
+ tf.shape(inp)[0]
act = tf.nn.leaky_relu
@@ -229,7 +429,7 @@ class ResNet32(object):
if stop_grad:
for k, v in weights.items():
- if type(v) == dict:
+ if isinstance(v, dict):
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
@@ -238,23 +438,73 @@ class ResNet32(object):
weights[k] = tf.stop_gradient(v)
# Make sure gradients are modified a bit
- inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
+ inp = smart_conv_block(inp, weights, reuse, "c1_pre", use_stride=False)
- hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, act=act)
- hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, act=act)
- hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, label=label, act=act)
+ hidden1 = smart_res_block(
+ inp, weights, reuse, "res_optim", adaptive=False, label=label, act=act
+ )
+ hidden2 = smart_res_block(
+ hidden1,
+ weights,
+ reuse,
+ "res_1",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ act=act,
+ )
+ hidden3 = smart_res_block(
+ hidden2,
+ weights,
+ reuse,
+ "res_2",
+ stop_batch=stop_batch,
+ label=label,
+ act=act,
+ )
if FLAGS.use_attention:
- hidden4 = smart_atten_block(hidden3, weights, reuse, 'atten', stop_at_grad=stop_at_grad, label=label)
+ hidden4 = smart_atten_block(
+ hidden3, weights, reuse, "atten", stop_at_grad=stop_at_grad, label=label
+ )
else:
- hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, act=act)
+ hidden4 = smart_res_block(
+ hidden3,
+ weights,
+ reuse,
+ "res_3",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ act=act,
+ )
- hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', stop_batch=stop_batch, adaptive=False, label=label, act=act)
- compact = hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
+ hidden5 = smart_res_block(
+ hidden4,
+ weights,
+ reuse,
+ "res_4",
+ stop_batch=stop_batch,
+ adaptive=False,
+ label=label,
+ act=act,
+ )
+ compact = hidden6 = smart_res_block(
+ hidden5,
+ weights,
+ reuse,
+ "res_5",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
hidden6 = tf.nn.relu(hidden6)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
- hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
+ hidden6 = smart_fc_block(hidden5, weights, reuse, "fc5")
energy = hidden6
@@ -269,9 +519,9 @@ class ResNet32Large(object):
self.dropout = train
self.train = train
- def construct_weights(self, scope=''):
+ def construct_weights(self, scope=""):
weights = {}
- dtype = tf.float32
+ tf.float32
if FLAGS.cclass:
classes = 10
@@ -280,32 +530,101 @@ class ResNet32Large(object):
with tf.variable_scope(scope):
# First block
- init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
- init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
+ init_conv_weight(weights, "c1_pre", 3, self.channels, self.dim_hidden)
+ init_res_weight(
+ weights,
+ "res_optim",
+ 3,
+ self.dim_hidden,
+ self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights, "res_2", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights,
+ "res_3",
+ 3,
+ self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_4",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_5",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_6",
+ 3,
+ 2 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_7",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_8",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_fc_weight(weights, "fc5", 4 * self.dim_hidden, 1, spec_norm=False)
- init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden, trainable_gamma=True)
+ init_attention_weight(
+ weights,
+ "atten",
+ 2 * self.dim_hidden,
+ self.dim_hidden,
+ trainable_gamma=True,
+ )
return weights
- def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
+ def forward(
+ self,
+ inp,
+ weights,
+ reuse=False,
+ scope="",
+ stop_grad=False,
+ label=None,
+ stop_at_grad=False,
+ stop_batch=False,
+ ):
weights = weights.copy()
- batch = tf.shape(inp)[0]
+ tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
- if type(v) == dict:
+ if isinstance(v, dict):
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
@@ -314,27 +633,122 @@ class ResNet32Large(object):
weights[k] = tf.stop_gradient(v)
# Make sure gradients are modified a bit
- inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
+ inp = smart_conv_block(inp, weights, reuse, "c1_pre", use_stride=False)
dropout = self.dropout
train = self.train
- hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, dropout=dropout, train=train)
- hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
- hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
- hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
+ hidden1 = smart_res_block(
+ inp,
+ weights,
+ reuse,
+ "res_optim",
+ adaptive=False,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
+ hidden2 = smart_res_block(
+ hidden1,
+ weights,
+ reuse,
+ "res_1",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
+ hidden3 = smart_res_block(
+ hidden2,
+ weights,
+ reuse,
+ "res_2",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
+ hidden4 = smart_res_block(
+ hidden3,
+ weights,
+ reuse,
+ "res_3",
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
if FLAGS.use_attention:
- hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
+ hidden5 = smart_atten_block(
+ hidden4, weights, reuse, "atten", stop_at_grad=stop_at_grad
+ )
else:
- hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
+ hidden5 = smart_res_block(
+ hidden4,
+ weights,
+ reuse,
+ "res_4",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
- hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
+ hidden6 = smart_res_block(
+ hidden5,
+ weights,
+ reuse,
+ "res_5",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
- hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
- hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
+ hidden7 = smart_res_block(
+ hidden6,
+ weights,
+ reuse,
+ "res_6",
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
+ hidden8 = smart_res_block(
+ hidden7,
+ weights,
+ reuse,
+ "res_7",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
- compact = hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
+ compact = hidden9 = smart_res_block(
+ hidden8,
+ weights,
+ reuse,
+ "res_8",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
if FLAGS.cclass:
hidden6 = tf.nn.leaky_relu(hidden9)
@@ -342,7 +756,7 @@ class ResNet32Large(object):
hidden6 = tf.nn.relu(hidden9)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
- hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
+ hidden6 = smart_fc_block(hidden5, weights, reuse, "fc5")
energy = hidden6
@@ -357,9 +771,9 @@ class ResNet32Wider(object):
self.dropout = train
self.train = train
- def construct_weights(self, scope=''):
+ def construct_weights(self, scope=""):
weights = {}
- dtype = tf.float32
+ tf.float32
if FLAGS.cclass and FLAGS.dataset == "cifar10":
classes = 10
@@ -370,32 +784,96 @@ class ResNet32Wider(object):
with tf.variable_scope(scope):
# First block
- init_conv_weight(weights, 'c1_pre', 3, self.channels, 128)
- init_res_weight(weights, 'res_optim', 3, 128, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
+ init_conv_weight(weights, "c1_pre", 3, self.channels, 128)
+ init_res_weight(
+ weights, "res_optim", 3, 128, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights, "res_2", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights,
+ "res_3",
+ 3,
+ self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_4",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_5",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_6",
+ 3,
+ 2 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_7",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_8",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_fc_weight(weights, "fc5", 4 * self.dim_hidden, 1, spec_norm=False)
- init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True)
+ init_attention_weight(
+ weights,
+ "atten",
+ self.dim_hidden,
+ self.dim_hidden / 2,
+ trainable_gamma=True,
+ )
return weights
- def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
+ def forward(
+ self,
+ inp,
+ weights,
+ reuse=False,
+ scope="",
+ stop_grad=False,
+ label=None,
+ stop_at_grad=False,
+ stop_batch=False,
+ ):
weights = weights.copy()
- batch = tf.shape(inp)[0]
+ tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
- if type(v) == dict:
+ if isinstance(v, dict):
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
@@ -409,28 +887,139 @@ class ResNet32Wider(object):
act = tf.nn.leaky_relu
# Make sure gradients are modified a bit
- inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
+ inp = smart_conv_block(
+ inp, weights, reuse, "c1_pre", use_stride=False, activation=act
+ )
dropout = self.dropout
train = self.train
- hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=True, label=label, dropout=dropout, train=train)
+ hidden1 = smart_res_block(
+ inp,
+ weights,
+ reuse,
+ "res_optim",
+ adaptive=True,
+ label=label,
+ dropout=dropout,
+ train=train,
+ )
if FLAGS.use_attention:
- hidden2 = smart_atten_block(hidden1, weights, reuse, 'atten', train=train, dropout=dropout, stop_at_grad=stop_at_grad)
+ hidden2 = smart_atten_block(
+ hidden1,
+ weights,
+ reuse,
+ "atten",
+ train=train,
+ dropout=dropout,
+ stop_at_grad=stop_at_grad,
+ )
else:
- hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
+ hidden2 = smart_res_block(
+ hidden1,
+ weights,
+ reuse,
+ "res_1",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
- hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
- hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
+ hidden3 = smart_res_block(
+ hidden2,
+ weights,
+ reuse,
+ "res_2",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
+ hidden4 = smart_res_block(
+ hidden3,
+ weights,
+ reuse,
+ "res_3",
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
- hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
+ hidden5 = smart_res_block(
+ hidden4,
+ weights,
+ reuse,
+ "res_4",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
- hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
+ hidden6 = smart_res_block(
+ hidden5,
+ weights,
+ reuse,
+ "res_5",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
- hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
- hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
+ hidden7 = smart_res_block(
+ hidden6,
+ weights,
+ reuse,
+ "res_6",
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
+ hidden8 = smart_res_block(
+ hidden7,
+ weights,
+ reuse,
+ "res_7",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
- hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
+ hidden9 = smart_res_block(
+ hidden8,
+ weights,
+ reuse,
+ "res_8",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
if FLAGS.swish_act:
hidden6 = act(hidden9)
@@ -438,7 +1027,7 @@ class ResNet32Wider(object):
hidden6 = tf.nn.relu(hidden9)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
- hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
+ hidden6 = smart_fc_block(hidden5, weights, reuse, "fc5")
energy = hidden6
return energy
@@ -450,9 +1039,9 @@ class ResNet32Larger(object):
self.channels = num_channels
self.dim_hidden = num_filters
- def construct_weights(self, scope=''):
+ def construct_weights(self, scope=""):
weights = {}
- dtype = tf.float32
+ tf.float32
if FLAGS.cclass:
classes = 10
@@ -461,39 +1050,142 @@ class ResNet32Larger(object):
with tf.variable_scope(scope):
# First block
- init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
- init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_2a', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_2b', 3, self.dim_hidden, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_5a', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_5b', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_8a', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_8b', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
- init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
+ init_conv_weight(weights, "c1_pre", 3, self.channels, self.dim_hidden)
+ init_res_weight(
+ weights,
+ "res_optim",
+ 3,
+ self.dim_hidden,
+ self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights, "res_2", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights, "res_2a", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights, "res_2b", 3, self.dim_hidden, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights,
+ "res_3",
+ 3,
+ self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_4",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_5",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_5a",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_5b",
+ 3,
+ 2 * self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_6",
+ 3,
+ 2 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_7",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_8",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_8a",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_8b",
+ 3,
+ 4 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_fc_weight(
+ weights, "fc_dense", 4 * 4 * 2 * self.dim_hidden, 4 * self.dim_hidden
+ )
+ init_fc_weight(weights, "fc5", 4 * self.dim_hidden, 1, spec_norm=False)
- init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True)
+ init_attention_weight(
+ weights,
+ "atten",
+ 2 * self.dim_hidden,
+ self.dim_hidden / 2,
+ trainable_gamma=True,
+ )
return weights
- def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
+ def forward(
+ self,
+ inp,
+ weights,
+ reuse=False,
+ scope="",
+ stop_grad=False,
+ label=None,
+ stop_at_grad=False,
+ stop_batch=False,
+ ):
weights = weights.copy()
- batch = tf.shape(inp)[0]
+ tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
if stop_grad:
for k, v in weights.items():
- if type(v) == dict:
+ if isinstance(v, dict):
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
@@ -502,29 +1194,145 @@ class ResNet32Larger(object):
weights[k] = tf.stop_gradient(v)
# Make sure gradients are modified a bit
- inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
+ inp = smart_conv_block(inp, weights, reuse, "c1_pre", use_stride=False)
- hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label)
- hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
- hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
- hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2a', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
- hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2b', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
- hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label)
+ hidden1 = smart_res_block(
+ inp, weights, reuse, "res_optim", adaptive=False, label=label
+ )
+ hidden2 = smart_res_block(
+ hidden1,
+ weights,
+ reuse,
+ "res_1",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ )
+ hidden3 = smart_res_block(
+ hidden2,
+ weights,
+ reuse,
+ "res_2",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ )
+ hidden3 = smart_res_block(
+ hidden3,
+ weights,
+ reuse,
+ "res_2a",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ )
+ hidden3 = smart_res_block(
+ hidden3,
+ weights,
+ reuse,
+ "res_2b",
+ stop_batch=stop_batch,
+ downsample=False,
+ adaptive=False,
+ label=label,
+ )
+ hidden4 = smart_res_block(
+ hidden3, weights, reuse, "res_3", stop_batch=stop_batch, label=label
+ )
if FLAGS.use_attention:
- hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
+ hidden5 = smart_atten_block(
+ hidden4, weights, reuse, "atten", stop_at_grad=stop_at_grad
+ )
else:
- hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
+ hidden5 = smart_res_block(
+ hidden4,
+ weights,
+ reuse,
+ "res_4",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
- hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
+ hidden6 = smart_res_block(
+ hidden5,
+ weights,
+ reuse,
+ "res_5",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
- hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
- hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
- hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label)
- hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
- hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
- hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
- compact = hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
+ hidden6 = smart_res_block(
+ hidden6,
+ weights,
+ reuse,
+ "res_5a",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
+ hidden6 = smart_res_block(
+ hidden6,
+ weights,
+ reuse,
+ "res_5b",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
+ hidden7 = smart_res_block(
+ hidden6, weights, reuse, "res_6", stop_batch=stop_batch, label=label
+ )
+ hidden8 = smart_res_block(
+ hidden7,
+ weights,
+ reuse,
+ "res_7",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
+ hidden9 = smart_res_block(
+ hidden8,
+ weights,
+ reuse,
+ "res_8",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
+ hidden9 = smart_res_block(
+ hidden9,
+ weights,
+ reuse,
+ "res_8a",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
+ compact = hidden9 = smart_res_block(
+ hidden9,
+ weights,
+ reuse,
+ "res_8b",
+ adaptive=False,
+ downsample=False,
+ stop_batch=stop_batch,
+ label=label,
+ )
if FLAGS.cclass:
hidden6 = tf.nn.leaky_relu(hidden9)
@@ -532,7 +1340,7 @@ class ResNet32Larger(object):
hidden6 = tf.nn.relu(hidden9)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
- hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
+ hidden6 = smart_fc_block(hidden5, weights, reuse, "fc5")
energy = hidden6
@@ -549,39 +1357,90 @@ class ResNet128(object):
self.dropout = train
self.train = train
- def construct_weights(self, scope=''):
+ def construct_weights(self, scope=""):
weights = {}
- dtype = tf.float32
+ tf.float32
classes = 1000
with tf.variable_scope(scope):
# First block
- init_conv_weight(weights, 'c1_pre', 3, self.channels, 64)
- init_res_weight(weights, 'res_optim', 3, 64, self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 8*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_9', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
- init_res_weight(weights, 'res_10', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
- init_fc_weight(weights, 'fc5', 8*self.dim_hidden , 1, spec_norm=False)
+ init_conv_weight(weights, "c1_pre", 3, self.channels, 64)
+ init_res_weight(
+ weights, "res_optim", 3, 64, self.dim_hidden, classes=classes
+ )
+ init_res_weight(
+ weights,
+ "res_3",
+ 3,
+ self.dim_hidden,
+ 2 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_5",
+ 3,
+ 2 * self.dim_hidden,
+ 4 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_7",
+ 3,
+ 4 * self.dim_hidden,
+ 8 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_9",
+ 3,
+ 8 * self.dim_hidden,
+ 8 * self.dim_hidden,
+ classes=classes,
+ )
+ init_res_weight(
+ weights,
+ "res_10",
+ 3,
+ 8 * self.dim_hidden,
+ 8 * self.dim_hidden,
+ classes=classes,
+ )
+ init_fc_weight(weights, "fc5", 8 * self.dim_hidden, 1, spec_norm=False)
-
- init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2., trainable_gamma=True)
+ init_attention_weight(
+ weights,
+ "atten",
+ self.dim_hidden,
+ self.dim_hidden / 2.0,
+ trainable_gamma=True,
+ )
return weights
- def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
+ def forward(
+ self,
+ inp,
+ weights,
+ reuse=False,
+ scope="",
+ stop_grad=False,
+ label=None,
+ stop_at_grad=False,
+ stop_batch=False,
+ ):
weights = weights.copy()
- batch = tf.shape(inp)[0]
+ tf.shape(inp)[0]
if not FLAGS.cclass:
label = None
-
if stop_grad:
for k, v in weights.items():
- if type(v) == dict:
+ if isinstance(v, dict):
v = v.copy()
weights[k] = v
for k_sub, v_sub in v.items():
@@ -598,17 +1457,91 @@ class ResNet128(object):
train = self.train
# Make sure gradients are modified a bit
- inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
- hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', label=label, dropout=dropout, train=train, downsample=True, adaptive=False)
+ inp = smart_conv_block(
+ inp, weights, reuse, "c1_pre", use_stride=False, activation=act
+ )
+ hidden1 = smart_res_block(
+ inp,
+ weights,
+ reuse,
+ "res_optim",
+ label=label,
+ dropout=dropout,
+ train=train,
+ downsample=True,
+ adaptive=False,
+ )
if FLAGS.use_attention:
- hidden1 = smart_atten_block(hidden1, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
+ hidden1 = smart_atten_block(
+ hidden1, weights, reuse, "atten", stop_at_grad=stop_at_grad
+ )
- hidden2 = smart_res_block(hidden1, weights, reuse, 'res_3', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
- hidden3 = smart_res_block(hidden2, weights, reuse, 'res_5', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
- hidden4 = smart_res_block(hidden3, weights, reuse, 'res_7', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=True)
- hidden5 = smart_res_block(hidden4, weights, reuse, 'res_9', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=False)
- hidden6 = smart_res_block(hidden5, weights, reuse, 'res_10', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=False, adaptive=False)
+ hidden2 = smart_res_block(
+ hidden1,
+ weights,
+ reuse,
+ "res_3",
+ stop_batch=stop_batch,
+ downsample=True,
+ adaptive=True,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
+ hidden3 = smart_res_block(
+ hidden2,
+ weights,
+ reuse,
+ "res_5",
+ stop_batch=stop_batch,
+ downsample=True,
+ adaptive=True,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ )
+ hidden4 = smart_res_block(
+ hidden3,
+ weights,
+ reuse,
+ "res_7",
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ downsample=True,
+ adaptive=True,
+ )
+ hidden5 = smart_res_block(
+ hidden4,
+ weights,
+ reuse,
+ "res_9",
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ downsample=True,
+ adaptive=False,
+ )
+ hidden6 = smart_res_block(
+ hidden5,
+ weights,
+ reuse,
+ "res_10",
+ stop_batch=stop_batch,
+ label=label,
+ dropout=dropout,
+ train=train,
+ act=act,
+ downsample=False,
+ adaptive=False,
+ )
if FLAGS.swish_act:
hidden6 = act(hidden6)
@@ -616,7 +1549,7 @@ class ResNet128(object):
hidden6 = tf.nn.relu(hidden6)
hidden5 = tf.reduce_sum(hidden6, [1, 2])
- hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
+ hidden6 = smart_fc_block(hidden5, weights, reuse, "fc5")
energy = hidden6
return energy
diff --git a/EBMs/test_inception.py b/EBMs/test_inception.py
index ca7e55b..e33df83 100644
--- a/EBMs/test_inception.py
+++ b/EBMs/test_inception.py
@@ -1,53 +1,56 @@
-import tensorflow as tf
-import numpy as np
-from tensorflow.python.platform import flags
-from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
import os.path as osp
-import os
-from utils import optimistic_restore, remap_restore, optimistic_remap_restore
-from tqdm import tqdm
import random
-from scipy.misc import imsave
-from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
-from torch.utils.data import DataLoader
-from baselines.common.tf_util import initialize
import horovod.tensorflow as hvd
+import numpy as np
+import tensorflow as tf
+from baselines.common.tf_util import initialize
+from data import Cifar10, Imagenet, TFImagenetLoader
+from fid import get_fid_score
+from inception import get_inception_score
+from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
+from scipy.misc import imsave
+from tensorflow.python.platform import flags
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from utils import optimistic_remap_restore
+
hvd.init()
-from inception import get_inception_score
-from fid import get_fid_score
-flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
-flags.DEFINE_string('exp', 'default', 'name of experiments')
-flags.DEFINE_bool('cclass', False, 'whether to condition on class')
+flags.DEFINE_string(
+ "logdir", "cachedir", "location where log of experiments will be stored"
+)
+flags.DEFINE_string("exp", "default", "name of experiments")
+flags.DEFINE_bool("cclass", False, "whether to condition on class")
# Architecture settings
-flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
-flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
-flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
-flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
-flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent')
-flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
-flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images')
-flags.DEFINE_integer('batch_size', 512, 'batch size')
-flags.DEFINE_integer('resume_iter', -1, 'resume iteration')
-flags.DEFINE_integer('ensemble', 10, 'number of ensembles')
-flags.DEFINE_integer('im_number', 50000, 'number of ensembles')
-flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations')
-flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output')
-flags.DEFINE_integer('idx', 0, 'save index')
-flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing')
-flags.DEFINE_bool('scaled', True, 'whether to scale noise added')
-flags.DEFINE_bool('large_model', False, 'whether to use a small or large model')
-flags.DEFINE_bool('larger_model', False, 'Whether to use a large model')
-flags.DEFINE_bool('wider_model', False, 'Whether to use a large model')
-flags.DEFINE_bool('single', False, 'single ')
-flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
-flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull')
+flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
+flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
+flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
+flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
+flags.DEFINE_float("step_lr", 10.0, "Size of steps for gradient descent")
+flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
+flags.DEFINE_float("proj_norm", 0.05, "Maximum change of input images")
+flags.DEFINE_integer("batch_size", 512, "batch size")
+flags.DEFINE_integer("resume_iter", -1, "resume iteration")
+flags.DEFINE_integer("ensemble", 10, "number of ensembles")
+flags.DEFINE_integer("im_number", 50000, "number of ensembles")
+flags.DEFINE_integer("repeat_scale", 100, "number of repeat iterations")
+flags.DEFINE_float("noise_scale", 0.005, "amount of noise to output")
+flags.DEFINE_integer("idx", 0, "save index")
+flags.DEFINE_integer("nomix", 10, "number of intervals to stop mixing")
+flags.DEFINE_bool("scaled", True, "whether to scale noise added")
+flags.DEFINE_bool("large_model", False, "whether to use a small or large model")
+flags.DEFINE_bool("larger_model", False, "Whether to use a large model")
+flags.DEFINE_bool("wider_model", False, "Whether to use a large model")
+flags.DEFINE_bool("single", False, "single ")
+flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
+flags.DEFINE_string("dataset", "cifar10", "cifar10 or imagenet or imagenetfull")
FLAGS = flags.FLAGS
+
class InceptionReplayBuffer(object):
def __init__(self, size):
"""Create Replay buffer.
@@ -72,14 +75,16 @@ class InceptionReplayBuffer(object):
self._label_storage.extend(list(labels))
else:
if batch_size + self._next_idx < self._maxsize:
- self._storage[self._next_idx:self._next_idx+batch_size] = list(ims)
- self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels)
+ self._storage[self._next_idx : self._next_idx + batch_size] = list(ims)
+ self._label_storage[self._next_idx : self._next_idx + batch_size] = (
+ list(labels)
+ )
else:
split_idx = self._maxsize - self._next_idx
- self._storage[self._next_idx:] = list(ims)[:split_idx]
- self._storage[:batch_size-split_idx] = list(ims)[split_idx:]
- self._label_storage[self._next_idx:] = list(labels)[:split_idx]
- self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:]
+ self._storage[self._next_idx :] = list(ims)[:split_idx]
+ self._storage[: batch_size - split_idx] = list(ims)[split_idx:]
+ self._label_storage[self._next_idx :] = list(labels)[:split_idx]
+ self._label_storage[: batch_size - split_idx] = list(labels)[split_idx:]
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
@@ -123,12 +128,13 @@ class InceptionReplayBuffer(object):
def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8)
+
def compute_inception(sess, target_vars):
- X_START = target_vars['X_START']
- Y_GT = target_vars['Y_GT']
- X_finals = target_vars['X_finals']
- NOISE_SCALE = target_vars['NOISE_SCALE']
- energy_noise = target_vars['energy_noise']
+ X_START = target_vars["X_START"]
+ Y_GT = target_vars["Y_GT"]
+ X_finals = target_vars["X_finals"]
+ NOISE_SCALE = target_vars["NOISE_SCALE"]
+ energy_noise = target_vars["energy_noise"]
size = FLAGS.im_number
num_steps = size // 1000
@@ -136,16 +142,21 @@ def compute_inception(sess, target_vars):
images = []
test_ims = []
-
if FLAGS.dataset == "cifar10":
test_dataset = Cifar10(full=True, noise=False)
elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
test_dataset = Imagenet(train=False)
if FLAGS.dataset != "imagenetfull":
- test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
+ test_dataloader = DataLoader(
+ test_dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=4,
+ shuffle=True,
+ drop_last=False,
+ )
else:
- test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)
+ test_dataloader = TFImagenetLoader("test", FLAGS.batch_size, 0, 1)
for data_corrupt, data, label_gt in tqdm(test_dataloader):
data = data.numpy()
@@ -155,7 +166,6 @@ def compute_inception(sess, target_vars):
test_ims = test_ims[:60000]
break
-
# n = min(len(images), len(test_ims))
print(len(test_ims))
# fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
@@ -187,12 +197,14 @@ def compute_inception(sess, target_vars):
x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
label = np.random.randint(0, classes, (FLAGS.batch_size))
label = identity[label]
- x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
+ x_new = sess.run(
+ [x_final], {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale}
+ )[0]
data_buffer.add(x_new, label)
else:
(x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
- keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
- label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
+ keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99
+ label_keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9
label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
label_corrupt = identity[label_corrupt]
x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
@@ -203,7 +215,10 @@ def compute_inception(sess, target_vars):
# else:
# noise_scale = [0.7]
- x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
+ x_new, e_noise = sess.run(
+ [x_final, energy_noise],
+ {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale},
+ )
data_buffer.set_elms(idx, x_new, label)
if FLAGS.im_number != 50000:
@@ -216,14 +231,22 @@ def compute_inception(sess, target_vars):
images.extend(list(ims))
- saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
+ saveim = osp.join("sandbox_cachedir", FLAGS.exp, "test{}.png".format(FLAGS.idx))
ims = ims[:100]
if FLAGS.dataset != "imagenetfull":
- im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
+ im_panel = (
+ ims.reshape((10, 10, 32, 32, 3))
+ .transpose((0, 2, 1, 3, 4))
+ .reshape((320, 320, 3))
+ )
else:
- im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3))
+ im_panel = (
+ ims.reshape((10, 10, 128, 128, 3))
+ .transpose((0, 2, 1, 3, 4))
+ .reshape((1280, 1280, 3))
+ )
imsave(saveim, im_panel)
print("Saved image!!!!")
@@ -237,8 +260,6 @@ def compute_inception(sess, target_vars):
print("FID of score {}".format(fid))
-
-
def main(model_list):
if FLAGS.dataset == "imagenetfull":
@@ -259,45 +280,55 @@ def main(model_list):
weights = []
for i, model_num in enumerate(model_list):
- weight = model.construct_weights('context_{}'.format(i))
+ weight = model.construct_weights("context_{}".format(i))
initialize()
- save_file = osp.join(logdir, 'model_{}'.format(model_num))
+ save_file = osp.join(logdir, "model_{}".format(model_num))
- v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i))
- v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list}
+ v_list = tf.get_collection(
+ tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(i)
+ )
+ v_map = {
+ (v.name.replace("context_{}".format(i), "context_0")[:-2]): v
+ for v in v_list
+ }
saver = tf.train.Saver(v_map)
try:
saver.restore(sess, save_file)
- except:
+ except BaseException:
optimistic_remap_restore(sess, save_file, i)
weights.append(weight)
-
if FLAGS.dataset == "imagenetfull":
- X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32)
+ X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
else:
- X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
+ X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
if FLAGS.dataset == "cifar10":
- Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
+ Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
else:
- Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
+ Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
X_finals = []
-
# Seperate loops
for weight in weights:
X = X_START
steps = tf.constant(0)
- c = lambda i, x: tf.less(i, FLAGS.num_steps)
+
+ def c(i, x):
+ return tf.less(i, FLAGS.num_steps)
+
def langevin_step(counter, X):
scale_rate = 1
- X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)
+ X = X + tf.random_normal(
+ tf.shape(X),
+ mean=0.0,
+ stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE,
+ )
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0]
@@ -305,7 +336,7 @@ def main(model_list):
if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
- X = X - FLAGS.step_lr * x_grad * scale_rate
+ X = X - FLAGS.step_lr * x_grad * scale_rate
X = tf.clip_by_value(X, 0, 1)
counter = counter + 1
@@ -318,16 +349,16 @@ def main(model_list):
X_finals.append(X_final)
target_vars = {}
- target_vars['X_START'] = X_START
- target_vars['Y_GT'] = Y_GT
- target_vars['X_finals'] = X_finals
- target_vars['NOISE_SCALE'] = NOISE_SCALE
- target_vars['energy_noise'] = energy_noise
+ target_vars["X_START"] = X_START
+ target_vars["Y_GT"] = Y_GT
+ target_vars["X_finals"] = X_finals
+ target_vars["NOISE_SCALE"] = NOISE_SCALE
+ target_vars["energy_noise"] = energy_noise
compute_inception(sess, target_vars)
if __name__ == "__main__":
# model_list = [117000, 116700]
- model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)]
+ model_list = [FLAGS.resume_iter - 300 * i for i in range(FLAGS.ensemble)]
main(model_list)
diff --git a/EBMs/train.py b/EBMs/train.py
index 716c4ea..b4b64b2 100644
--- a/EBMs/train.py
+++ b/EBMs/train.py
@@ -1,34 +1,38 @@
-import tensorflow as tf
-import numpy as np
-from tensorflow.python.platform import flags
-
-from data import Imagenet, Cifar10, DSprites, Mnist, TFImagenetLoader
-from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, MnistNet, ResNet128
-import os.path as osp
import os
-from baselines.logger import TensorBoardOutputFormat
-from utils import average_gradients, ReplayBuffer, optimistic_restore
-from tqdm import tqdm
+import os.path as osp
import random
-from torch.utils.data import DataLoader
-import time as time
-from io import StringIO
-from tensorflow.core.util import event_pb2
-import torch
-import numpy as np
-from custom_adam import AdamOptimizer
-from scipy.misc import imsave
-import matplotlib.pyplot as plt
-from hmc import hmc
+import horovod.tensorflow as hvd
+import numpy as np
+import tensorflow as tf
+import torch
+from baselines.logger import TensorBoardOutputFormat
+from custom_adam import AdamOptimizer
+from data import Cifar10, DSprites, Imagenet, Mnist, TFImagenetLoader
+from hmc import hmc
+from inception import get_inception_score
+from models import (
+ DspritesNet,
+ MnistNet,
+ ResNet32,
+ ResNet32Large,
+ ResNet32Larger,
+ ResNet32Wider,
+ ResNet128,
+)
from mpi4py import MPI
+from tensorflow.core.util import event_pb2
+from tensorflow.python.platform import flags
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from utils import ReplayBuffer, average_gradients, optimistic_restore
+
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
-import horovod.tensorflow as hvd
+
hvd.init()
-from inception import get_inception_score
torch.manual_seed(hvd.rank())
np.random.seed(hvd.rank())
@@ -38,67 +42,81 @@ FLAGS = flags.FLAGS
# Dataset Options
-flags.DEFINE_string('datasource', 'random',
- 'initialization for chains, either random or default (decorruption)')
-flags.DEFINE_string('dataset','mnist',
- 'dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)')
-flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
-flags.DEFINE_bool('single', False, 'whether to debug by training on a single image')
-flags.DEFINE_integer('data_workers', 4,
- 'Number of different data workers to load data in parallel')
+flags.DEFINE_string(
+ "datasource",
+ "random",
+ "initialization for chains, either random or default (decorruption)",
+)
+flags.DEFINE_string(
+ "dataset", "mnist", "dsprites, cifar10, imagenet (32x32) or imagenetfull (128x128)"
+)
+flags.DEFINE_integer("batch_size", 256, "Size of inputs")
+flags.DEFINE_bool("single", False, "whether to debug by training on a single image")
+flags.DEFINE_integer(
+ "data_workers", 4, "Number of different data workers to load data in parallel"
+)
# General Experiment Settings
-flags.DEFINE_string('logdir', 'cachedir',
- 'location where log of experiments will be stored')
-flags.DEFINE_string('exp', 'default', 'name of experiments')
-flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches')
-flags.DEFINE_integer('save_interval', 1000,'save outputs every so many batches')
-flags.DEFINE_integer('test_interval', 1000,'evaluate outputs every so many batches')
-flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
-flags.DEFINE_bool('train', True, 'whether to train or test')
-flags.DEFINE_integer('epoch_num', 10000, 'Number of Epochs to train on')
-flags.DEFINE_float('lr', 3e-4, 'Learning for training')
-flags.DEFINE_integer('num_gpus', 1, 'number of gpus to train on')
+flags.DEFINE_string(
+ "logdir", "cachedir", "location where log of experiments will be stored"
+)
+flags.DEFINE_string("exp", "default", "name of experiments")
+flags.DEFINE_integer("log_interval", 10, "log outputs every so many batches")
+flags.DEFINE_integer("save_interval", 1000, "save outputs every so many batches")
+flags.DEFINE_integer("test_interval", 1000, "evaluate outputs every so many batches")
+flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from")
+flags.DEFINE_bool("train", True, "whether to train or test")
+flags.DEFINE_integer("epoch_num", 10000, "Number of Epochs to train on")
+flags.DEFINE_float("lr", 3e-4, "Learning for training")
+flags.DEFINE_integer("num_gpus", 1, "number of gpus to train on")
# EBM Specific Experiments Settings
-flags.DEFINE_float('ml_coeff', 1.0, 'Maximum Likelihood Coefficients')
-flags.DEFINE_float('l2_coeff', 1.0, 'L2 Penalty training')
-flags.DEFINE_bool('cclass', False, 'Whether to conditional training in models')
-flags.DEFINE_bool('model_cclass', False,'use unsupervised clustering to infer fake labels')
-flags.DEFINE_integer('temperature', 1, 'Temperature for energy function')
-flags.DEFINE_string('objective', 'cd', 'use either contrastive divergence objective(least stable),'
- 'logsumexp(more stable)'
- 'softplus(most stable)')
-flags.DEFINE_bool('zero_kl', False, 'whether to zero out the kl loss')
+flags.DEFINE_float("ml_coeff", 1.0, "Maximum Likelihood Coefficients")
+flags.DEFINE_float("l2_coeff", 1.0, "L2 Penalty training")
+flags.DEFINE_bool("cclass", False, "Whether to conditional training in models")
+flags.DEFINE_bool(
+ "model_cclass", False, "use unsupervised clustering to infer fake labels"
+)
+flags.DEFINE_integer("temperature", 1, "Temperature for energy function")
+flags.DEFINE_string(
+ "objective",
+ "cd",
+ "use either contrastive divergence objective(least stable),"
+ "logsumexp(more stable)"
+ "softplus(most stable)",
+)
+flags.DEFINE_bool("zero_kl", False, "whether to zero out the kl loss")
# Setting for MCMC sampling
-flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
-flags.DEFINE_string('proj_norm_type', 'li', 'Either li or l2 ball projection')
-flags.DEFINE_integer('num_steps', 20, 'Steps of gradient descent for training')
-flags.DEFINE_float('step_lr', 1.0, 'Size of steps for gradient descent')
-flags.DEFINE_bool('replay_batch', False, 'Use MCMC chains initialized from a replay buffer.')
-flags.DEFINE_bool('hmc', False, 'Whether to use HMC sampling to train models')
-flags.DEFINE_float('noise_scale', 1.,'Relative amount of noise for MCMC')
-flags.DEFINE_bool('pcd', False, 'whether to use pcd training instead')
+flags.DEFINE_float("proj_norm", 0.0, "Maximum change of input images")
+flags.DEFINE_string("proj_norm_type", "li", "Either li or l2 ball projection")
+flags.DEFINE_integer("num_steps", 20, "Steps of gradient descent for training")
+flags.DEFINE_float("step_lr", 1.0, "Size of steps for gradient descent")
+flags.DEFINE_bool(
+ "replay_batch", False, "Use MCMC chains initialized from a replay buffer."
+)
+flags.DEFINE_bool("hmc", False, "Whether to use HMC sampling to train models")
+flags.DEFINE_float("noise_scale", 1.0, "Relative amount of noise for MCMC")
+flags.DEFINE_bool("pcd", False, "whether to use pcd training instead")
# Architecture Settings
-flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets')
-flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
-flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
-flags.DEFINE_bool('large_model', False, 'whether to use a large model')
-flags.DEFINE_bool('larger_model', False, 'Deeper ResNet32 Network')
-flags.DEFINE_bool('wider_model', False, 'Wider ResNet32 Network')
+flags.DEFINE_integer("num_filters", 64, "number of filters for conv nets")
+flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
+flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
+flags.DEFINE_bool("large_model", False, "whether to use a large model")
+flags.DEFINE_bool("larger_model", False, "Deeper ResNet32 Network")
+flags.DEFINE_bool("wider_model", False, "Wider ResNet32 Network")
# Dataset settings
-flags.DEFINE_bool('mixup', False, 'whether to add mixup to training images')
-flags.DEFINE_bool('augment', False, 'whether to augmentations to images')
-flags.DEFINE_float('rescale', 1.0, 'Factor to rescale inputs from 0-1 box')
+flags.DEFINE_bool("mixup", False, "whether to add mixup to training images")
+flags.DEFINE_bool("augment", False, "whether to augmentations to images")
+flags.DEFINE_float("rescale", 1.0, "Factor to rescale inputs from 0-1 box")
# Dsprites specific experiments
-flags.DEFINE_bool('cond_shape', False, 'condition of shape type')
-flags.DEFINE_bool('cond_size', False, 'condition of shape size')
-flags.DEFINE_bool('cond_pos', False, 'condition of position loc')
-flags.DEFINE_bool('cond_rot', False, 'condition of rot')
+flags.DEFINE_bool("cond_shape", False, "condition of shape type")
+flags.DEFINE_bool("cond_size", False, "condition of shape size")
+flags.DEFINE_bool("cond_pos", False, "condition of position loc")
+flags.DEFINE_bool("cond_rot", False, "condition of rot")
FLAGS.step_lr = FLAGS.step_lr * FLAGS.rescale
@@ -113,14 +131,16 @@ def compress_x_mod(x_mod):
def decompress_x_mod(x_mod):
- x_mod = x_mod / 256 * FLAGS.rescale + \
- np.random.uniform(0, 1 / 256 * FLAGS.rescale, x_mod.shape)
+ x_mod = x_mod / 256 * FLAGS.rescale + np.random.uniform(
+ 0, 1 / 256 * FLAGS.rescale, x_mod.shape
+ )
return x_mod
def make_image(tensor):
"""Convert an numpy representation image to Image protobuf"""
from PIL import Image
+
if len(tensor.shape) == 4:
_, height, width, channel = tensor.shape
elif len(tensor.shape) == 3:
@@ -131,14 +151,17 @@ def make_image(tensor):
tensor = tensor.astype(np.uint8)
image = Image.fromarray(tensor)
import io
+
output = io.BytesIO()
- image.save(output, format='PNG')
+ image.save(output, format="PNG")
image_string = output.getvalue()
output.close()
- return tf.Summary.Image(height=height,
- width=width,
- colorspace=channel,
- encoded_image_string=image_string)
+ return tf.Summary.Image(
+ height=height,
+ width=width,
+ colorspace=channel,
+ encoded_image_string=image_string,
+ )
def log_image(im, logger, tag, step=0):
@@ -154,37 +177,39 @@ def log_image(im, logger, tag, step=0):
def rescale_im(image):
image = np.clip(image, 0, FLAGS.rescale)
- if FLAGS.dataset == 'mnist' or FLAGS.dataset == 'dsprites':
- return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
+ if FLAGS.dataset == "mnist" or FLAGS.dataset == "dsprites":
+ return (np.clip((FLAGS.rescale - image) * 256 / FLAGS.rescale, 0, 255)).astype(
+ np.uint8
+ )
else:
return (np.clip(image * 256 / FLAGS.rescale, 0, 255)).astype(np.uint8)
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
- X = target_vars['X']
- Y = target_vars['Y']
- X_NOISE = target_vars['X_NOISE']
- train_op = target_vars['train_op']
- energy_pos = target_vars['energy_pos']
- energy_neg = target_vars['energy_neg']
- loss_energy = target_vars['loss_energy']
- loss_ml = target_vars['loss_ml']
- loss_total = target_vars['total_loss']
- gvs = target_vars['gvs']
- x_grad = target_vars['x_grad']
- x_grad_first = target_vars['x_grad_first']
- x_off = target_vars['x_off']
- temp = target_vars['temp']
- x_mod = target_vars['x_mod']
- LABEL = target_vars['LABEL']
- LABEL_POS = target_vars['LABEL_POS']
- weights = target_vars['weights']
- test_x_mod = target_vars['test_x_mod']
- eps = target_vars['eps_begin']
- label_ent = target_vars['label_ent']
+ X = target_vars["X"]
+ Y = target_vars["Y"]
+ X_NOISE = target_vars["X_NOISE"]
+ train_op = target_vars["train_op"]
+ energy_pos = target_vars["energy_pos"]
+ energy_neg = target_vars["energy_neg"]
+ loss_energy = target_vars["loss_energy"]
+ loss_ml = target_vars["loss_ml"]
+ loss_total = target_vars["total_loss"]
+ gvs = target_vars["gvs"]
+ x_grad = target_vars["x_grad"]
+ x_grad_first = target_vars["x_grad_first"]
+ x_off = target_vars["x_off"]
+ temp = target_vars["temp"]
+ x_mod = target_vars["x_mod"]
+ LABEL = target_vars["LABEL"]
+ LABEL_POS = target_vars["LABEL_POS"]
+ weights = target_vars["weights"]
+ test_x_mod = target_vars["test_x_mod"]
+ eps = target_vars["eps_begin"]
+ label_ent = target_vars["label_ent"]
if FLAGS.use_attention:
- gamma = weights[0]['atten']['gamma']
+ gamma = weights[0]["atten"]["gamma"]
else:
gamma = tf.zeros(1)
@@ -206,13 +231,13 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
gamma,
x_grad_first,
label_ent,
- *gvs_dict.keys()]
+ *gvs_dict.keys(),
+ ]
output = [train_op, x_mod]
replay_buffer = ReplayBuffer(10000)
itr = resume_iter
x_mod = None
- gd_steps = 1
dataloader_iterator = iter(dataloader)
best_inception = 0.0
@@ -220,7 +245,7 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
for epoch in range(FLAGS.epoch_num):
for data_corrupt, data, label in dataloader:
data_corrupt = data_corrupt_init = data_corrupt.numpy()
- data_corrupt_init = data_corrupt.copy()
+ data_corrupt.copy()
data = data.numpy()
label = label.numpy()
@@ -239,10 +264,8 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_batch = decompress_x_mod(replay_batch)
replay_mask = (
- np.random.uniform(
- 0,
- FLAGS.rescale,
- FLAGS.batch_size) > 0.05)
+ np.random.uniform(0, FLAGS.rescale, FLAGS.batch_size) > 0.05
+ )
data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.pcd:
@@ -256,26 +279,40 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
feed_dict[LABEL_POS] = label_init
if itr % FLAGS.log_interval == 0:
- _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
- grads = sess.run(log_output, feed_dict)
+ (
+ _,
+ e_pos,
+ e_neg,
+ eps,
+ loss_e,
+ loss_ml,
+ loss_total,
+ x_grad,
+ x_off,
+ x_mod,
+ gamma,
+ x_grad_first,
+ label_ent,
+ *grads,
+ ) = sess.run(log_output, feed_dict)
kvs = {}
- kvs['e_pos'] = e_pos.mean()
- kvs['e_pos_std'] = e_pos.std()
- kvs['e_neg'] = e_neg.mean()
- kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
- kvs['e_neg_std'] = e_neg.std()
- kvs['temp'] = temp
- kvs['loss_e'] = loss_e.mean()
- kvs['eps'] = eps.mean()
- kvs['label_ent'] = label_ent
- kvs['loss_ml'] = loss_ml.mean()
- kvs['loss_total'] = loss_total.mean()
- kvs['x_grad'] = np.abs(x_grad).mean()
- kvs['x_grad_first'] = np.abs(x_grad_first).mean()
- kvs['x_off'] = x_off.mean()
- kvs['iter'] = itr
- kvs['gamma'] = gamma
+ kvs["e_pos"] = e_pos.mean()
+ kvs["e_pos_std"] = e_pos.std()
+ kvs["e_neg"] = e_neg.mean()
+ kvs["e_diff"] = kvs["e_pos"] - kvs["e_neg"]
+ kvs["e_neg_std"] = e_neg.std()
+ kvs["temp"] = temp
+ kvs["loss_e"] = loss_e.mean()
+ kvs["eps"] = eps.mean()
+ kvs["label_ent"] = label_ent
+ kvs["loss_ml"] = loss_ml.mean()
+ kvs["loss_total"] = loss_total.mean()
+ kvs["x_grad"] = np.abs(x_grad).mean()
+ kvs["x_grad_first"] = np.abs(x_grad_first).mean()
+ kvs["x_off"] = x_off.mean()
+ kvs["iter"] = itr
+ kvs["gamma"] = gamma
for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
kvs[k] = np.abs(v).max()
@@ -292,13 +329,14 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
saver.save(
- sess,
- osp.join(
- FLAGS.logdir,
- FLAGS.exp,
- 'model_{}'.format(itr)))
+ sess, osp.join(FLAGS.logdir, FLAGS.exp, "model_{}".format(itr))
+ )
- if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
+ if (
+ itr % FLAGS.test_interval == 0
+ and hvd.rank() == 0
+ and FLAGS.dataset != "2d"
+ ):
try_im = x_mod
orig_im = data_corrupt.squeeze()
actual_im = rescale_im(data)
@@ -307,18 +345,16 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
try_im = rescale_im(try_im).squeeze()
for i, (im, t_im, actual_im_i) in enumerate(
- zip(orig_im[:20], try_im[:20], actual_im)):
+ zip(orig_im[:20], try_im[:20], actual_im)
+ ):
shape = orig_im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
- new_im[:, size:2 * size] = t_im
- new_im[:, 2 * size:] = actual_im_i
+ new_im[:, size : 2 * size] = t_im
+ new_im[:, 2 * size :] = actual_im_i
- log_image(
- new_im, logger, 'train_gen_{}'.format(itr), step=i)
-
- test_im = x_mod
+ log_image(new_im, logger, "train_gen_{}".format(itr), step=i)
try:
data_corrupt, data, label = next(dataloader_iterator)
@@ -328,16 +364,21 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
data_corrupt = data_corrupt.numpy()
- if FLAGS.replay_batch and (
- x_mod is not None) and len(replay_buffer) > 0:
+ if (
+ FLAGS.replay_batch
+ and (x_mod is not None)
+ and len(replay_buffer) > 0
+ ):
replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_batch = decompress_x_mod(replay_batch)
- replay_mask = (
- np.random.uniform(
- 0, 1, (FLAGS.batch_size)) > 0.05)
+ replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.05
data_corrupt[replay_mask] = replay_batch[replay_mask]
- if FLAGS.dataset == 'cifar10' or FLAGS.dataset == 'imagenet' or FLAGS.dataset == 'imagenetfull':
+ if (
+ FLAGS.dataset == "cifar10"
+ or FLAGS.dataset == "imagenet"
+ or FLAGS.dataset == "imagenetfull"
+ ):
n = 128
if FLAGS.dataset == "imagenetfull":
@@ -345,19 +386,19 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
if len(replay_buffer) > n:
data_corrupt = decompress_x_mod(replay_buffer.sample(n))
- elif FLAGS.dataset == 'imagenetfull':
+ elif FLAGS.dataset == "imagenetfull":
data_corrupt = np.random.uniform(
- 0, FLAGS.rescale, (n, 128, 128, 3))
+ 0, FLAGS.rescale, (n, 128, 128, 3)
+ )
else:
data_corrupt = np.random.uniform(
- 0, FLAGS.rescale, (n, 32, 32, 3))
+ 0, FLAGS.rescale, (n, 32, 32, 3)
+ )
- if FLAGS.dataset == 'cifar10':
+ if FLAGS.dataset == "cifar10":
label = np.eye(10)[np.random.randint(0, 10, (n))]
else:
- label = np.eye(1000)[
- np.random.randint(
- 0, 1000, (n))]
+ label = np.eye(1000)[np.random.randint(0, 1000, (n))]
feed_dict[X_NOISE] = data_corrupt
@@ -376,63 +417,58 @@ def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
try_im = rescale_im(try_im).squeeze()
for i, (im, t_im, actual_im_i) in enumerate(
- zip(orig_im[:20], try_im[:20], actual_im)):
+ zip(orig_im[:20], try_im[:20], actual_im)
+ ):
shape = orig_im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
- new_im[:, size:2 * size] = t_im
- new_im[:, 2 * size:] = actual_im_i
- log_image(
- new_im, logger, 'val_gen_{}'.format(itr), step=i)
+ new_im[:, size : 2 * size] = t_im
+ new_im[:, 2 * size :] = actual_im_i
+ log_image(new_im, logger, "val_gen_{}".format(itr), step=i)
score, std = get_inception_score(list(try_im), splits=1)
- print(
- "Inception score of {} with std of {}".format(
- score, std))
+ print("Inception score of {} with std of {}".format(score, std))
kvs = {}
- kvs['inception_score'] = score
- kvs['inception_score_std'] = std
+ kvs["inception_score"] = score
+ kvs["inception_score_std"] = std
logger.writekvs(kvs)
if score > best_inception:
best_inception = score
- saver.save(
- sess,
- osp.join(
- FLAGS.logdir,
- FLAGS.exp,
- 'model_best'))
+ saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, "model_best"))
if itr > 60000 and FLAGS.dataset == "mnist":
assert False
itr += 1
- saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
+ saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, "model_{}".format(itr)))
-cifar10_map = {0: 'airplane',
- 1: 'automobile',
- 2: 'bird',
- 3: 'cat',
- 4: 'deer',
- 5: 'dog',
- 6: 'frog',
- 7: 'horse',
- 8: 'ship',
- 9: 'truck'}
+cifar10_map = {
+ 0: "airplane",
+ 1: "automobile",
+ 2: "bird",
+ 3: "cat",
+ 4: "deer",
+ 5: "dog",
+ 6: "frog",
+ 7: "horse",
+ 8: "ship",
+ 9: "truck",
+}
def test(target_vars, saver, sess, logger, dataloader):
- X_NOISE = target_vars['X_NOISE']
- X = target_vars['X']
- Y = target_vars['Y']
- LABEL = target_vars['LABEL']
- energy_start = target_vars['energy_start']
- x_mod = target_vars['x_mod']
- x_mod = target_vars['test_x_mod']
- energy_neg = target_vars['energy_neg']
+ X_NOISE = target_vars["X_NOISE"]
+ target_vars["X"]
+ Y = target_vars["Y"]
+ LABEL = target_vars["LABEL"]
+ energy_start = target_vars["energy_start"]
+ x_mod = target_vars["x_mod"]
+ x_mod = target_vars["test_x_mod"]
+ energy_neg = target_vars["energy_neg"]
np.random.seed(1)
random.seed(1)
@@ -447,49 +483,55 @@ def test(target_vars, saver, sess, logger, dataloader):
if FLAGS.cclass:
try_im, energy_orig, energy = sess.run(
- output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label})
+ output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label}
+ )
else:
try_im, energy_orig, energy = sess.run(
- output, {X_NOISE: orig_im, Y: label[0:1]})
+ output, {X_NOISE: orig_im, Y: label[0:1]}
+ )
orig_im = rescale_im(orig_im)
try_im = rescale_im(try_im)
actual_im = rescale_im(data)
for i, (im, energy_i, t_im, energy, label_i, actual_im_i) in enumerate(
- zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
+ zip(orig_im, energy_orig, try_im, energy, label, actual_im)
+ ):
label_i = np.array(label_i)
shape = im.shape[1:]
new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
size = shape[1]
new_im[:, :size] = im
- new_im[:, size:2 * size] = t_im
+ new_im[:, size : 2 * size] = t_im
if FLAGS.cclass:
label_i = np.where(label_i == 1)[0][0]
- if FLAGS.dataset == 'cifar10':
- log_image(new_im, logger, '{}_{:.4f}_now_{:.4f}_{}'.format(
- i, energy_i[0], energy[0], cifar10_map[label_i]), step=i)
+ if FLAGS.dataset == "cifar10":
+ log_image(
+ new_im,
+ logger,
+ "{}_{:.4f}_now_{:.4f}_{}".format(
+ i, energy_i[0], energy[0], cifar10_map[label_i]
+ ),
+ step=i,
+ )
else:
log_image(
new_im,
logger,
- '{}_{:.4f}_now_{:.4f}_{}'.format(
- i,
- energy_i[0],
- energy[0],
- label_i),
- step=i)
+ "{}_{:.4f}_now_{:.4f}_{}".format(
+ i, energy_i[0], energy[0], label_i
+ ),
+ step=i,
+ )
else:
log_image(
new_im,
logger,
- '{}_{:.4f}_now_{:.4f}'.format(
- i,
- energy_i[0],
- energy[0]),
- step=i)
+ "{}_{:.4f}_now_{:.4f}".format(i, energy_i[0], energy[0]),
+ step=i,
+ )
test_ims = list(try_im)
real_ims = list(actual_im)
@@ -505,10 +547,12 @@ def test(target_vars, saver, sess, logger, dataloader):
if FLAGS.cclass:
try_im, energy_orig, energy = sess.run(
- output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label})
+ output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label}
+ )
else:
try_im, energy_orig, energy = sess.run(
- output, {X_NOISE: data_corrupt, Y: label[0:1]})
+ output, {X_NOISE: data_corrupt, Y: label[0:1]}
+ )
try_im = rescale_im(try_im)
real_im = rescale_im(data)
@@ -533,7 +577,7 @@ def main():
LABEL = None
print("Loading data...")
- if FLAGS.dataset == 'cifar10':
+ if FLAGS.dataset == "cifar10":
dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
channel_num = 3
@@ -544,24 +588,15 @@ def main():
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
if FLAGS.large_model:
- model = ResNet32Large(
- num_channels=channel_num,
- num_filters=128,
- train=True)
+ model = ResNet32Large(num_channels=channel_num, num_filters=128, train=True)
elif FLAGS.larger_model:
- model = ResNet32Larger(
- num_channels=channel_num,
- num_filters=128)
+ model = ResNet32Larger(num_channels=channel_num, num_filters=128)
elif FLAGS.wider_model:
- model = ResNet32Wider(
- num_channels=channel_num,
- num_filters=192)
+ model = ResNet32Wider(num_channels=channel_num, num_filters=192)
else:
- model = ResNet32(
- num_channels=channel_num,
- num_filters=128)
+ model = ResNet32(num_channels=channel_num, num_filters=128)
- elif FLAGS.dataset == 'imagenet':
+ elif FLAGS.dataset == "imagenet":
dataset = Imagenet(train=True)
test_dataset = Imagenet(train=False)
channel_num = 3
@@ -570,41 +605,34 @@ def main():
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
- model = ResNet32Wider(
- num_channels=channel_num,
- num_filters=256)
+ model = ResNet32Wider(num_channels=channel_num, num_filters=256)
- elif FLAGS.dataset == 'imagenetfull':
+ elif FLAGS.dataset == "imagenetfull":
channel_num = 3
X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
- model = ResNet128(
- num_channels=channel_num,
- num_filters=64)
+ model = ResNet128(num_channels=channel_num, num_filters=64)
- elif FLAGS.dataset == 'mnist':
+ elif FLAGS.dataset == "mnist":
dataset = Mnist(rescale=FLAGS.rescale)
- test_dataset = dataset
channel_num = 1
X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
- model = MnistNet(
- num_channels=channel_num,
- num_filters=FLAGS.num_filters)
+ model = MnistNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
- elif FLAGS.dataset == 'dsprites':
+ elif FLAGS.dataset == "dsprites":
dataset = DSprites(
cond_shape=FLAGS.cond_shape,
cond_size=FLAGS.cond_size,
cond_pos=FLAGS.cond_pos,
- cond_rot=FLAGS.cond_rot)
- test_dataset = dataset
+ cond_rot=FLAGS.cond_rot,
+ )
channel_num = 1
X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
@@ -641,24 +669,28 @@ def main():
cond_size=FLAGS.cond_size,
cond_shape=FLAGS.cond_shape,
cond_pos=FLAGS.cond_pos,
- cond_rot=FLAGS.cond_rot)
+ cond_rot=FLAGS.cond_rot,
+ )
print("Done loading...")
if FLAGS.dataset == "imagenetfull":
# In the case of full imagenet, use custom_tensorflow dataloader
- data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
+ data_loader = TFImagenetLoader(
+ "train", FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale
+ )
else:
data_loader = DataLoader(
dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.data_workers,
drop_last=True,
- shuffle=True)
+ shuffle=True,
+ )
- batch_size = FLAGS.batch_size
+ FLAGS.batch_size
- weights = [model.construct_weights('context_0')]
+ weights = [model.construct_weights("context_0")]
Y = tf.placeholder(shape=(None), dtype=tf.int32)
@@ -667,10 +699,8 @@ def main():
X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
- LABEL_SPLIT_INIT = list(LABEL_SPLIT)
+ list(LABEL_SPLIT)
tower_grads = []
- tower_gen_grads = []
- x_mod_list = []
optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
optimizer = hvd.DistributedOptimizer(optimizer)
@@ -683,26 +713,30 @@ def main():
tf.convert_to_tensor(
np.reshape(
np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
- (FLAGS.batch_size * 10, 10)),
- dtype=tf.float32),
+ (FLAGS.batch_size * 10, 10),
+ ),
+ dtype=tf.float32,
+ ),
trainable=False,
- dtype=tf.float32)
+ dtype=tf.float32,
+ )
x_split = tf.tile(
- tf.reshape(
- X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
+ tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1)
+ )
x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
energy_pos = model.forward(
- x_split,
- weights[0],
- label=label_tensor,
- stop_at_grad=False)
+ x_split, weights[0], label=label_tensor, stop_at_grad=False
+ )
energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
energy_partition_est = tf.reduce_logsumexp(
- energy_pos_full, axis=1, keepdims=True)
+ energy_pos_full, axis=1, keepdims=True
+ )
uniform = tf.random_uniform(tf.shape(energy_pos_full))
- label_tensor = tf.argmax(-energy_pos_full -
- tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
+ label_tensor = tf.argmax(
+ -energy_pos_full - tf.log(-tf.log(uniform)) - energy_partition_est,
+ axis=1,
+ )
label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
label = tf.Print(label, [label_tensor, energy_pos_full])
LABEL_SPLIT[j] = label
@@ -710,10 +744,9 @@ def main():
else:
energy_pos = [
model.forward(
- X_SPLIT[j],
- weights[0],
- label=LABEL_POS_SPLIT[j],
- stop_at_grad=False)]
+ X_SPLIT[j], weights[0], label=LABEL_POS_SPLIT[j], stop_at_grad=False
+ )
+ ]
energy_pos = tf.concat(energy_pos, axis=0)
print("Building graph...")
@@ -722,42 +755,57 @@ def main():
x_grads = []
energy_negs = []
- loss_energys = []
- energy_negs.extend([model.forward(tf.stop_gradient(
- x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
+ energy_negs.extend(
+ [
+ model.forward(
+ tf.stop_gradient(x_mod),
+ weights[0],
+ label=LABEL_SPLIT[j],
+ stop_at_grad=False,
+ reuse=True,
+ )
+ ]
+ )
eps_begin = tf.zeros(1)
steps = tf.constant(0)
- c = lambda i, x: tf.less(i, FLAGS.num_steps)
+
+ def c(i, x):
+ return tf.less(i, FLAGS.num_steps)
def langevin_step(counter, x_mod):
- x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
- mean=0.0,
- stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)
+ x_mod = x_mod + tf.random_normal(
+ tf.shape(x_mod),
+ mean=0.0,
+ stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale,
+ )
energy_noise = energy_start = tf.concat(
- [model.forward(
+ [
+ model.forward(
x_mod,
weights[0],
label=LABEL_SPLIT[j],
reuse=True,
stop_at_grad=False,
- stop_batch=True)],
- axis=0)
+ stop_batch=True,
+ )
+ ],
+ axis=0,
+ )
x_grad, label_grad = tf.gradients(
- FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
- energy_noise_old = energy_noise
+ FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]]
+ )
lr = FLAGS.step_lr
if FLAGS.proj_norm != 0.0:
- if FLAGS.proj_norm_type == 'l2':
+ if FLAGS.proj_norm_type == "l2":
x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
- elif FLAGS.proj_norm_type == 'li':
- x_grad = tf.clip_by_value(
- x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
+ elif FLAGS.proj_norm_type == "li":
+ x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
else:
print("Other types of projection are not supported!!!")
assert False
@@ -766,10 +814,11 @@ def main():
if FLAGS.hmc:
# Step size should be tuned to get around 65% acceptance
def energy(x):
- return FLAGS.temperature * \
- model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)
+ return FLAGS.temperature * model.forward(
+ x, weights[0], label=LABEL_SPLIT[j], reuse=True
+ )
- x_last = hmc(x_mod, 15., 10, energy)
+ x_last = hmc(x_mod, 15.0, 10, energy)
else:
x_last = x_mod - (lr) * x_grad
@@ -782,8 +831,9 @@ def main():
steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))
- energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
- stop_at_grad=False, reuse=True)
+ energy_eval = model.forward(
+ x_mod, weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True
+ )
x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
x_grads.append(x_grad)
@@ -793,22 +843,20 @@ def main():
weights[0],
label=LABEL_SPLIT[j],
stop_at_grad=False,
- reuse=True))
+ reuse=True,
+ )
+ )
test_x_mod = x_mod
temp = FLAGS.temperature
energy_neg = energy_negs[-1]
- x_off = tf.reduce_mean(
- tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))
+ x_off = tf.reduce_mean(tf.abs(x_mod[: tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))
loss_energy = model.forward(
- x_mod,
- weights[0],
- reuse=True,
- label=LABEL,
- stop_grad=True)
+ x_mod, weights[0], reuse=True, label=LABEL, stop_grad=True
+ )
print("Finished processing loop construction ...")
@@ -817,38 +865,40 @@ def main():
if FLAGS.cclass or FLAGS.model_cclass:
label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
label_prob = label_sum / tf.reduce_sum(label_sum)
- label_ent = -tf.reduce_sum(label_prob *
- tf.math.log(label_prob + 1e-7))
+ label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7))
else:
label_ent = tf.zeros(1)
- target_vars['label_ent'] = label_ent
+ target_vars["label_ent"] = label_ent
if FLAGS.train:
- if FLAGS.objective == 'logsumexp':
- pos_term = temp * energy_pos
- energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
+ if FLAGS.objective == "logsumexp":
+ temp * energy_pos
+ energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
pos_loss = tf.reduce_mean(temp * energy_pos)
neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
- elif FLAGS.objective == 'cd':
+ elif FLAGS.objective == "cd":
pos_loss = tf.reduce_mean(temp * energy_pos)
neg_loss = -tf.reduce_mean(temp * energy_neg)
loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
- elif FLAGS.objective == 'softplus':
- loss_ml = FLAGS.ml_coeff * \
- tf.nn.softplus(temp * (energy_pos - energy_neg))
+ elif FLAGS.objective == "softplus":
+ loss_ml = FLAGS.ml_coeff * tf.nn.softplus(
+ temp * (energy_pos - energy_neg)
+ )
loss_total = tf.reduce_mean(loss_ml)
if not FLAGS.zero_kl:
loss_total = loss_total + tf.reduce_mean(loss_energy)
- loss_total = loss_total + \
- FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))
+ loss_total = loss_total + FLAGS.l2_coeff * (
+ tf.reduce_mean(tf.square(energy_pos))
+ + tf.reduce_mean(tf.square((energy_neg)))
+ )
print("Started gradient computation...")
gvs = optimizer.compute_gradients(loss_total)
@@ -860,38 +910,38 @@ def main():
print("Finished applying gradients.")
- target_vars['loss_ml'] = loss_ml
- target_vars['total_loss'] = loss_total
- target_vars['loss_energy'] = loss_energy
- target_vars['weights'] = weights
- target_vars['gvs'] = gvs
+ target_vars["loss_ml"] = loss_ml
+ target_vars["total_loss"] = loss_total
+ target_vars["loss_energy"] = loss_energy
+ target_vars["weights"] = weights
+ target_vars["gvs"] = gvs
- target_vars['X'] = X
- target_vars['Y'] = Y
- target_vars['LABEL'] = LABEL
- target_vars['LABEL_POS'] = LABEL_POS
- target_vars['X_NOISE'] = X_NOISE
- target_vars['energy_pos'] = energy_pos
- target_vars['energy_start'] = energy_negs[0]
+ target_vars["X"] = X
+ target_vars["Y"] = Y
+ target_vars["LABEL"] = LABEL
+ target_vars["LABEL_POS"] = LABEL_POS
+ target_vars["X_NOISE"] = X_NOISE
+ target_vars["energy_pos"] = energy_pos
+ target_vars["energy_start"] = energy_negs[0]
if len(x_grads) >= 1:
- target_vars['x_grad'] = x_grads[-1]
- target_vars['x_grad_first'] = x_grads[0]
+ target_vars["x_grad"] = x_grads[-1]
+ target_vars["x_grad_first"] = x_grads[0]
else:
- target_vars['x_grad'] = tf.zeros(1)
- target_vars['x_grad_first'] = tf.zeros(1)
+ target_vars["x_grad"] = tf.zeros(1)
+ target_vars["x_grad_first"] = tf.zeros(1)
- target_vars['x_mod'] = x_mod
- target_vars['x_off'] = x_off
- target_vars['temp'] = temp
- target_vars['energy_neg'] = energy_neg
- target_vars['test_x_mod'] = test_x_mod
- target_vars['eps_begin'] = eps_begin
+ target_vars["x_mod"] = x_mod
+ target_vars["x_off"] = x_off
+ target_vars["temp"] = temp
+ target_vars["energy_neg"] = energy_neg
+ target_vars["test_x_mod"] = test_x_mod
+ target_vars["eps_begin"] = eps_begin
if FLAGS.train:
grads = average_gradients(tower_grads)
train_op = optimizer.apply_gradients(grads)
- target_vars['train_op'] = train_op
+ target_vars["train_op"] = train_op
config = tf.ConfigProto()
@@ -900,8 +950,7 @@ def main():
sess = tf.Session(config=config)
- saver = loader = tf.train.Saver(
- max_to_keep=30, keep_checkpoint_every_n_hours=6)
+ saver = loader = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6)
total_parameters = 0
for variable in tf.trainable_variables():
@@ -918,7 +967,7 @@ def main():
resume_itr = 0
if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
- model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
+ model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
resume_itr = FLAGS.resume_iter
# saver.restore(sess, model_file)
optimistic_restore(sess, model_file)
@@ -930,9 +979,7 @@ def main():
print("End broadcast")
if FLAGS.train:
- train(target_vars, saver, sess,
- logger, data_loader, resume_itr,
- logdir)
+ train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)
test(target_vars, saver, sess, logger, data_loader)
diff --git a/EBMs/utils.py b/EBMs/utils.py
index c53c73f..5cea716 100644
--- a/EBMs/utils.py
+++ b/EBMs/utils.py
@@ -1,19 +1,21 @@
-""" Utility functions. """
-import numpy as np
+"""Utility functions."""
+
import os
import random
-import tensorflow as tf
-import warnings
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.framework import sort
from tensorflow.contrib.layers.python import layers as tf_layers
from tensorflow.python.platform import flags
-from tensorflow.contrib.framework import sort
FLAGS = flags.FLAGS
-flags.DEFINE_integer('spec_iter', 1, 'Number of iterations to normalize spectrum of matrix')
-flags.DEFINE_float('spec_norm_val', 1.0, 'Desired norm of matrices')
-flags.DEFINE_bool('downsample', False, 'Wheter to do average pool downsampling')
-flags.DEFINE_bool('spec_eval', False, 'Set to true to prevent spectral updates')
+flags.DEFINE_integer(
+ "spec_iter", 1, "Number of iterations to normalize spectrum of matrix"
+)
+flags.DEFINE_float("spec_norm_val", 1.0, "Desired norm of matrices")
+flags.DEFINE_bool("downsample", False, "Wheter to do average pool downsampling")
+flags.DEFINE_bool("spec_eval", False, "Set to true to prevent spectral updates")
def get_median(v):
@@ -23,10 +25,11 @@ def get_median(v):
def set_seed(seed):
- import torch
- import numpy
import random
+ import numpy
+ import torch
+
torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)
@@ -59,12 +62,11 @@ class ReplayBuffer(object):
self._storage.extend(list(ims))
else:
if batch_size + self._next_idx < self._maxsize:
- self._storage[self._next_idx:self._next_idx +
- batch_size] = list(ims)
+ self._storage[self._next_idx : self._next_idx + batch_size] = list(ims)
else:
split_idx = self._maxsize - self._next_idx
- self._storage[self._next_idx:] = list(ims)[:split_idx]
- self._storage[:batch_size - split_idx] = list(ims)[split_idx:]
+ self._storage[self._next_idx :] = list(ims)[:split_idx]
+ self._storage[: batch_size - split_idx] = list(ims)[split_idx:]
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
def _encode_sample(self, idxes):
@@ -93,74 +95,87 @@ class ReplayBuffer(object):
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
"""
- idxes = [random.randint(0, len(self._storage) - 1)
- for _ in range(batch_size)]
+ idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
return self._encode_sample(idxes)
def get_weight(
- name,
- shape,
- gain=np.sqrt(2),
- use_wscale=False,
- fan_in=None,
- spec_norm=False,
- zero=False,
- fc=False):
+ name,
+ shape,
+ gain=np.sqrt(2),
+ use_wscale=False,
+ fan_in=None,
+ spec_norm=False,
+ zero=False,
+ fc=False,
+):
if fan_in is None:
fan_in = np.prod(shape[:-1])
std = gain / np.sqrt(fan_in) # He init
if use_wscale:
- wscale = tf.constant(np.float32(std), name=name + 'wscale')
- var = tf.get_variable(
- name + 'weight',
- shape=shape,
- initializer=tf.initializers.random_normal()) * wscale
+ wscale = tf.constant(np.float32(std), name=name + "wscale")
+ var = (
+ tf.get_variable(
+ name + "weight",
+ shape=shape,
+ initializer=tf.initializers.random_normal(),
+ )
+ * wscale
+ )
elif spec_norm:
if zero:
var = tf.get_variable(
shape=shape,
- name=name + 'weight',
- initializer=tf.initializers.random_normal(
- stddev=1e-10))
+ name=name + "weight",
+ initializer=tf.initializers.random_normal(stddev=1e-10),
+ )
var = spectral_normed_weight(var, name, lower_bound=True, fc=fc)
else:
var = tf.get_variable(
- name + 'weight',
+ name + "weight",
shape=shape,
- initializer=tf.initializers.random_normal())
+ initializer=tf.initializers.random_normal(),
+ )
var = spectral_normed_weight(var, name, fc=fc)
else:
if zero:
var = tf.get_variable(
- name + 'weight',
- shape=shape,
- initializer=tf.initializers.zero())
+ name + "weight", shape=shape, initializer=tf.initializers.zero()
+ )
else:
var = tf.get_variable(
- name + 'weight',
+ name + "weight",
shape=shape,
- initializer=tf.contrib.layers.xavier_initializer(
- dtype=tf.float32))
+ initializer=tf.contrib.layers.xavier_initializer(dtype=tf.float32),
+ )
return var
def pixel_norm(x, epsilon=1e-8):
- with tf.variable_scope('PixelNorm'):
- return x * tf.rsqrt(tf.reduce_mean(tf.square(x),
- axis=[1, 2], keepdims=True) + epsilon)
+ with tf.variable_scope("PixelNorm"):
+ return x * tf.rsqrt(
+ tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) + epsilon
+ )
# helper
def get_images(paths, labels, nb_samples=None, shuffle=True):
if nb_samples is not None:
- def sampler(x): return random.sample(x, nb_samples)
+
+ def sampler(x):
+ return random.sample(x, nb_samples)
+
else:
- def sampler(x): return x
- images = [(i, os.path.join(path, image))
- for i, path in zip(labels, paths)
- for image in sampler(os.listdir(path))]
+
+ def sampler(x):
+ return x
+
+ images = [
+ (i, os.path.join(path, image))
+ for i, path in zip(labels, paths)
+ for image in sampler(os.listdir(path))
+ ]
if shuffle:
random.shuffle(images)
return images
@@ -170,10 +185,15 @@ def optimistic_restore(session, save_file, v_prefix=None):
reader = tf.train.NewCheckpointReader(save_file)
saved_shapes = reader.get_variable_to_shape_map()
- var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.get_collection(
- tf.GraphKeys.GLOBAL_VARIABLES) if var.name.split(':')[0] in saved_shapes])
+ var_names = sorted(
+ [
+ (var.name, var.name.split(":")[0])
+ for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ if var.name.split(":")[0] in saved_shapes
+ ]
+ )
restore_vars = []
- with tf.variable_scope('', reuse=True):
+ with tf.variable_scope("", reuse=True):
for var_name, saved_var_name in var_names:
try:
curr_var = tf.get_variable(saved_var_name)
@@ -196,18 +216,28 @@ def optimistic_remap_restore(session, save_file, v_prefix):
saved_shapes = reader.get_variable_to_shape_map()
vars_list = tf.get_collection(
- tf.GraphKeys.GLOBAL_VARIABLES,
- scope='context_{}'.format(v_prefix))
- var_names = sorted([(var.name.split(':')[0], var) for var in vars_list if (
- (var.name.split(':')[0]).replace('context_{}'.format(v_prefix), 'context_0') in saved_shapes)])
- restore_vars = []
+ tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(v_prefix)
+ )
+ var_names = sorted(
+ [
+ (var.name.split(":")[0], var)
+ for var in vars_list
+ if (
+ (var.name.split(":")[0]).replace(
+ "context_{}".format(v_prefix), "context_0"
+ )
+ in saved_shapes
+ )
+ ]
+ )
v_map = {}
- with tf.variable_scope('', reuse=True):
+ with tf.variable_scope("", reuse=True):
for saved_var_name, curr_var in var_names:
var_shape = curr_var.get_shape().as_list()
saved_var_name = saved_var_name.replace(
- 'context_{}'.format(v_prefix), 'context_0')
+ "context_{}".format(v_prefix), "context_0"
+ )
if var_shape == saved_shapes[saved_var_name]:
v_map[saved_var_name] = curr_var
else:
@@ -221,10 +251,15 @@ def optimistic_remap_restore(session, save_file, v_prefix):
def remap_restore(session, save_file, i):
reader = tf.train.NewCheckpointReader(save_file)
saved_shapes = reader.get_variable_to_shape_map()
- var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables()
- if var.name.split(':')[0] in saved_shapes])
+ var_names = sorted(
+ [
+ (var.name, var.name.split(":")[0])
+ for var in tf.global_variables()
+ if var.name.split(":")[0] in saved_shapes
+ ]
+ )
restore_vars = []
- with tf.variable_scope('', reuse=True):
+ with tf.variable_scope("", reuse=True):
for var_name, saved_var_name in var_names:
try:
curr_var = tf.get_variable(saved_var_name)
@@ -241,15 +276,8 @@ def remap_restore(session, save_file, i):
# Network weight initializers
def init_conv_weight(
- weights,
- scope,
- k,
- c_in,
- c_out,
- spec_norm=True,
- zero=False,
- scale=1.0,
- classes=1):
+ weights, scope, k, c_in, c_out, spec_norm=True, zero=False, scale=1.0, classes=1
+):
if spec_norm:
spec_norm = FLAGS.spec_norm
@@ -257,50 +285,43 @@ def init_conv_weight(
conv_weights = {}
with tf.variable_scope(scope):
if zero:
- conv_weights['c'] = get_weight(
- 'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True)
+ conv_weights["c"] = get_weight(
+ "c", [k, k, c_in, c_out], spec_norm=spec_norm, zero=True
+ )
else:
- conv_weights['c'] = get_weight(
- 'c', [k, k, c_in, c_out], spec_norm=spec_norm)
+ conv_weights["c"] = get_weight(
+ "c", [k, k, c_in, c_out], spec_norm=spec_norm
+ )
- conv_weights['b'] = tf.get_variable(
- shape=[c_out], name='b', initializer=tf.initializers.zeros())
+ conv_weights["b"] = tf.get_variable(
+ shape=[c_out], name="b", initializer=tf.initializers.zeros()
+ )
if classes != 1:
- conv_weights['g'] = tf.get_variable(
- shape=[
- classes,
- c_out],
- name='g',
- initializer=tf.initializers.ones())
- conv_weights['gb'] = tf.get_variable(
- shape=[
- classes,
- c_in],
- name='gb',
- initializer=tf.initializers.zeros())
+ conv_weights["g"] = tf.get_variable(
+ shape=[classes, c_out], name="g", initializer=tf.initializers.ones()
+ )
+ conv_weights["gb"] = tf.get_variable(
+ shape=[classes, c_in], name="gb", initializer=tf.initializers.zeros()
+ )
else:
- conv_weights['g'] = tf.get_variable(
- shape=[c_out], name='g', initializer=tf.initializers.ones())
- conv_weights['gb'] = tf.get_variable(
- shape=[c_in], name='gb', initializer=tf.initializers.zeros())
+ conv_weights["g"] = tf.get_variable(
+ shape=[c_out], name="g", initializer=tf.initializers.ones()
+ )
+ conv_weights["gb"] = tf.get_variable(
+ shape=[c_in], name="gb", initializer=tf.initializers.zeros()
+ )
- conv_weights['cb'] = tf.get_variable(
- shape=[c_in], name='cb', initializer=tf.initializers.zeros())
+ conv_weights["cb"] = tf.get_variable(
+ shape=[c_in], name="cb", initializer=tf.initializers.zeros()
+ )
weights[scope] = conv_weights
def init_convt_weight(
- weights,
- scope,
- k,
- c_in,
- c_out,
- spec_norm=True,
- zero=False,
- scale=1.0,
- classes=1):
+ weights, scope, k, c_in, c_out, spec_norm=True, zero=False, scale=1.0, classes=1
+):
if spec_norm:
spec_norm = FLAGS.spec_norm
@@ -308,67 +329,66 @@ def init_convt_weight(
conv_weights = {}
with tf.variable_scope(scope):
if zero:
- conv_weights['c'] = get_weight(
- 'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True)
+ conv_weights["c"] = get_weight(
+ "c", [k, k, c_in, c_out], spec_norm=spec_norm, zero=True
+ )
else:
- conv_weights['c'] = get_weight(
- 'c', [k, k, c_in, c_out], spec_norm=spec_norm)
+ conv_weights["c"] = get_weight(
+ "c", [k, k, c_in, c_out], spec_norm=spec_norm
+ )
- conv_weights['b'] = tf.get_variable(
- shape=[c_in], name='b', initializer=tf.initializers.zeros())
+ conv_weights["b"] = tf.get_variable(
+ shape=[c_in], name="b", initializer=tf.initializers.zeros()
+ )
if classes != 1:
- conv_weights['g'] = tf.get_variable(
- shape=[
- classes,
- c_in],
- name='g',
- initializer=tf.initializers.ones())
- conv_weights['gb'] = tf.get_variable(
- shape=[
- classes,
- c_out],
- name='gb',
- initializer=tf.initializers.zeros())
+ conv_weights["g"] = tf.get_variable(
+ shape=[classes, c_in], name="g", initializer=tf.initializers.ones()
+ )
+ conv_weights["gb"] = tf.get_variable(
+ shape=[classes, c_out], name="gb", initializer=tf.initializers.zeros()
+ )
else:
- conv_weights['g'] = tf.get_variable(
- shape=[c_in], name='g', initializer=tf.initializers.ones())
- conv_weights['gb'] = tf.get_variable(
- shape=[c_out], name='gb', initializer=tf.initializers.zeros())
+ conv_weights["g"] = tf.get_variable(
+ shape=[c_in], name="g", initializer=tf.initializers.ones()
+ )
+ conv_weights["gb"] = tf.get_variable(
+ shape=[c_out], name="gb", initializer=tf.initializers.zeros()
+ )
- conv_weights['cb'] = tf.get_variable(
- shape=[c_in], name='cb', initializer=tf.initializers.zeros())
+ conv_weights["cb"] = tf.get_variable(
+ shape=[c_in], name="cb", initializer=tf.initializers.zeros()
+ )
weights[scope] = conv_weights
def init_attention_weight(
- weights,
- scope,
- c_in,
- k,
- trainable_gamma=True,
- spec_norm=True):
+ weights, scope, c_in, k, trainable_gamma=True, spec_norm=True
+):
if spec_norm:
spec_norm = FLAGS.spec_norm
atten_weights = {}
with tf.variable_scope(scope):
- atten_weights['q'] = get_weight(
- 'atten_q', [1, 1, c_in, k], spec_norm=spec_norm)
- atten_weights['q_b'] = tf.get_variable(
- shape=[k], name='atten_q_b1', initializer=tf.initializers.zeros())
- atten_weights['k'] = get_weight(
- 'atten_k', [1, 1, c_in, k], spec_norm=spec_norm)
- atten_weights['k_b'] = tf.get_variable(
- shape=[k], name='atten_k_b1', initializer=tf.initializers.zeros())
- atten_weights['v'] = get_weight(
- 'atten_v', [1, 1, c_in, c_in], spec_norm=spec_norm)
- atten_weights['v_b'] = tf.get_variable(
- shape=[c_in], name='atten_v_b1', initializer=tf.initializers.zeros())
- atten_weights['gamma'] = tf.get_variable(
- shape=[1], name='gamma', initializer=tf.initializers.zeros())
+ atten_weights["q"] = get_weight("atten_q", [1, 1, c_in, k], spec_norm=spec_norm)
+ atten_weights["q_b"] = tf.get_variable(
+ shape=[k], name="atten_q_b1", initializer=tf.initializers.zeros()
+ )
+ atten_weights["k"] = get_weight("atten_k", [1, 1, c_in, k], spec_norm=spec_norm)
+ atten_weights["k_b"] = tf.get_variable(
+ shape=[k], name="atten_k_b1", initializer=tf.initializers.zeros()
+ )
+ atten_weights["v"] = get_weight(
+ "atten_v", [1, 1, c_in, c_in], spec_norm=spec_norm
+ )
+ atten_weights["v_b"] = tf.get_variable(
+ shape=[c_in], name="atten_v_b1", initializer=tf.initializers.zeros()
+ )
+ atten_weights["gamma"] = tf.get_variable(
+ shape=[1], name="gamma", initializer=tf.initializers.zeros()
+ )
weights[scope] = atten_weights
@@ -380,24 +400,25 @@ def init_fc_weight(weights, scope, c_in, c_out, spec_norm=True):
spec_norm = FLAGS.spec_norm
with tf.variable_scope(scope):
- fc_weights['w'] = get_weight(
- 'w', [c_in, c_out], spec_norm=spec_norm, fc=True)
- fc_weights['b'] = tf.get_variable(
- shape=[c_out], name='b', initializer=tf.initializers.zeros())
+ fc_weights["w"] = get_weight("w", [c_in, c_out], spec_norm=spec_norm, fc=True)
+ fc_weights["b"] = tf.get_variable(
+ shape=[c_out], name="b", initializer=tf.initializers.zeros()
+ )
weights[scope] = fc_weights
def init_res_weight(
- weights,
- scope,
- k,
- c_in,
- c_out,
- hidden_dim=None,
- spec_norm=True,
- res_scale=1.0,
- classes=1):
+ weights,
+ scope,
+ k,
+ c_in,
+ c_out,
+ hidden_dim=None,
+ spec_norm=True,
+ res_scale=1.0,
+ classes=1,
+):
if not hidden_dim:
hidden_dim = c_in
@@ -407,36 +428,38 @@ def init_res_weight(
init_conv_weight(
weights,
- scope +
- '_res_c1',
+ scope + "_res_c1",
k,
c_in,
c_out,
spec_norm=spec_norm,
scale=res_scale,
- classes=classes)
+ classes=classes,
+ )
init_conv_weight(
weights,
- scope + '_res_c2',
+ scope + "_res_c2",
k,
c_out,
c_out,
spec_norm=spec_norm,
zero=True,
scale=res_scale,
- classes=classes)
+ classes=classes,
+ )
if c_in != c_out:
init_conv_weight(
weights,
- scope +
- '_res_adaptive',
+ scope + "_res_adaptive",
k,
c_in,
c_out,
spec_norm=spec_norm,
scale=res_scale,
- classes=classes)
+ classes=classes,
+ )
+
# Network forward helpers
@@ -445,32 +468,28 @@ def smart_conv_block(inp, weights, reuse, scope, use_stride=True, **kwargs):
weights = weights[scope]
return conv_block(
inp,
- weights['c'],
- weights['b'],
+ weights["c"],
+ weights["b"],
reuse,
scope,
- scale=weights['g'],
- bias=weights['gb'],
- class_bias=weights['cb'],
+ scale=weights["g"],
+ bias=weights["gb"],
+ class_bias=weights["cb"],
use_stride=use_stride,
- **kwargs)
+ **kwargs,
+ )
def smart_convt_block(
- inp,
- weights,
- reuse,
- scope,
- output_dim,
- upsample=True,
- label=None):
+ inp, weights, reuse, scope, output_dim, upsample=True, label=None
+):
weights = weights[scope]
- cweight = weights['c']
- bweight = weights['b']
- scale = weights['g']
- bias = weights['gb']
- class_bias = weights['cb']
+ cweight = weights["c"]
+ weights["b"]
+ scale = weights["g"]
+ bias = weights["gb"]
+ class_bias = weights["cb"]
if upsample:
stride = [1, 2, 2, 1]
@@ -485,15 +504,14 @@ def smart_convt_block(
inp = inp + bias
- shape = cweight.get_shape()
- conv_output = tf.nn.conv2d_transpose(inp,
- cweight,
- [tf.shape(inp)[0],
- output_dim,
- output_dim,
- cweight.get_shape().as_list()[-2]],
- stride,
- 'SAME')
+ cweight.get_shape()
+ conv_output = tf.nn.conv2d_transpose(
+ inp,
+ cweight,
+ [tf.shape(inp)[0], output_dim, output_dim, cweight.get_shape().as_list()[-2]],
+ stride,
+ "SAME",
+ )
if label is not None:
scale_batch = tf.matmul(label, scale) + class_bias
@@ -509,31 +527,33 @@ def smart_convt_block(
def smart_res_block(
- inp,
- weights,
- reuse,
- scope,
- downsample=True,
- adaptive=True,
- stop_batch=False,
- upsample=False,
- label=None,
- act=tf.nn.leaky_relu,
- dropout=False,
- train=False,
- **kwargs):
- gn1 = weights[scope + '_res_c1']
- gn2 = weights[scope + '_res_c2']
+ inp,
+ weights,
+ reuse,
+ scope,
+ downsample=True,
+ adaptive=True,
+ stop_batch=False,
+ upsample=False,
+ label=None,
+ act=tf.nn.leaky_relu,
+ dropout=False,
+ train=False,
+ **kwargs,
+):
+ weights[scope + "_res_c1"]
+ weights[scope + "_res_c2"]
c1 = smart_conv_block(
inp,
weights,
reuse,
- scope + '_res_c1',
+ scope + "_res_c1",
use_stride=False,
activation=None,
extra_bias=True,
label=label,
- **kwargs)
+ **kwargs,
+ )
if dropout:
c1 = tf.layers.dropout(c1, rate=0.5, training=train)
@@ -543,36 +563,38 @@ def smart_res_block(
c1,
weights,
reuse,
- scope + '_res_c2',
+ scope + "_res_c2",
use_stride=False,
activation=None,
use_scale=True,
extra_bias=True,
label=label,
- **kwargs)
+ **kwargs,
+ )
if adaptive:
c_bypass = smart_conv_block(
inp,
weights,
reuse,
- scope +
- '_res_adaptive',
+ scope + "_res_adaptive",
use_stride=False,
activation=None,
- **kwargs)
+ **kwargs,
+ )
else:
c_bypass = inp
res = c2 + c_bypass
if upsample:
- res_shape = tf.shape(res)
+ tf.shape(res)
res_shape_list = res.get_shape()
res = tf.image.resize_nearest_neighbor(
- res, [2 * res_shape_list[1], 2 * res_shape_list[2]])
+ res, [2 * res_shape_list[1], 2 * res_shape_list[2]]
+ )
elif downsample:
- res = tf.nn.avg_pool(res, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
+ res = tf.nn.avg_pool(res, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")
res = act(res)
@@ -584,33 +606,35 @@ def smart_res_block_optim(inp, weights, reuse, scope, **kwargs):
inp,
weights,
reuse,
- scope + '_res_c1',
+ scope + "_res_c1",
use_stride=False,
activation=None,
- **kwargs)
+ **kwargs,
+ )
c1 = tf.nn.leaky_relu(c1)
c2 = smart_conv_block(
c1,
weights,
reuse,
- scope + '_res_c2',
+ scope + "_res_c2",
use_stride=False,
activation=None,
- **kwargs)
+ **kwargs,
+ )
- inp = tf.nn.avg_pool(inp, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
+ inp = tf.nn.avg_pool(inp, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")
c_bypass = smart_conv_block(
inp,
weights,
reuse,
- scope +
- '_res_adaptive',
+ scope + "_res_adaptive",
use_stride=False,
activation=None,
- **kwargs)
- c2 = tf.nn.avg_pool(c2, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
+ **kwargs,
+ )
+ c2 = tf.nn.avg_pool(c2, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")
- res = c2 + c_bypass
+ c2 + c_bypass
return c2
@@ -621,6 +645,7 @@ def groupsort(k=4):
inp = sort(tf.reshape(inp, (-1, 4)))
inp = tf.reshape(inp, old_shape)
return inp
+
return sortact
@@ -628,52 +653,54 @@ def smart_atten_block(inp, weights, reuse, scope, **kwargs):
w = weights[scope]
return attention(
inp,
- w['q'],
- w['q_b'],
- w['k'],
- w['k_b'],
- w['v'],
- w['v_b'],
- w['gamma'],
+ w["q"],
+ w["q_b"],
+ w["k"],
+ w["k_b"],
+ w["v"],
+ w["v_b"],
+ w["gamma"],
reuse,
scope,
- **kwargs)
+ **kwargs,
+ )
def smart_fc_block(inp, weights, reuse, scope, use_bias=True):
weights = weights[scope]
- output = tf.matmul(inp, weights['w'])
+ output = tf.matmul(inp, weights["w"])
if use_bias:
- output = output + weights['b']
+ output = output + weights["b"]
return output
# Network helpers
def conv_block(
- inp,
- cweight,
- bweight,
- reuse,
- scope,
- use_stride=True,
- activation=tf.nn.leaky_relu,
- pn=False,
- bn=False,
- gn=False,
- ln=False,
- scale=None,
- bias=None,
- class_bias=None,
- use_bias=False,
- downsample=False,
- stop_batch=False,
- use_scale=False,
- extra_bias=False,
- average=False,
- label=None):
- """ Perform, conv, batch norm, nonlinearity, and max pool """
+ inp,
+ cweight,
+ bweight,
+ reuse,
+ scope,
+ use_stride=True,
+ activation=tf.nn.leaky_relu,
+ pn=False,
+ bn=False,
+ gn=False,
+ ln=False,
+ scale=None,
+ bias=None,
+ class_bias=None,
+ use_bias=False,
+ downsample=False,
+ stop_batch=False,
+ use_scale=False,
+ extra_bias=False,
+ average=False,
+ label=None,
+):
+ """Perform, conv, batch norm, nonlinearity, and max pool"""
stride, no_stride = [1, 2, 2, 1], [1, 1, 1, 1]
_, h, w, _ = inp.get_shape()
@@ -695,9 +722,9 @@ def conv_block(
inp = inp + bias
if not use_stride:
- conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME')
+ conv_output = tf.nn.conv2d(inp, cweight, no_stride, "SAME")
else:
- conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME')
+ conv_output = tf.nn.conv2d(inp, cweight, stride, "SAME")
if use_scale:
if label is not None:
@@ -721,8 +748,7 @@ def conv_block(
if pn:
conv_output = pixel_norm(conv_output)
if gn:
- conv_output = group_norm(
- conv_output, scale, bias, stop_batch=stop_batch)
+ conv_output = group_norm(conv_output, scale, bias, stop_batch=stop_batch)
if ln:
conv_output = layer_norm(conv_output, scale, bias)
@@ -732,17 +758,11 @@ def conv_block(
return conv_output
-def conv_block_1d(
- inp,
- cweight,
- bweight,
- reuse,
- scope,
- activation=tf.nn.leaky_relu):
- """ Perform, conv, batch norm, nonlinearity, and max pool """
+def conv_block_1d(inp, cweight, bweight, reuse, scope, activation=tf.nn.leaky_relu):
+ """Perform, conv, batch norm, nonlinearity, and max pool"""
stride = 1
- conv_output = tf.nn.conv1d(inp, cweight, stride, 'SAME') + bweight
+ conv_output = tf.nn.conv1d(inp, cweight, stride, "SAME") + bweight
if activation is not None:
conv_output = activation(conv_output)
@@ -751,21 +771,22 @@ def conv_block_1d(
def conv_block_3d(
- inp,
- cweight,
- bweight,
- reuse,
- scope,
- use_stride=True,
- activation=tf.nn.leaky_relu,
- pn=False,
- bn=False,
- gn=False,
- ln=False,
- scale=None,
- bias=None,
- use_bias=False):
- """ Perform, conv, batch norm, nonlinearity, and max pool """
+ inp,
+ cweight,
+ bweight,
+ reuse,
+ scope,
+ use_stride=True,
+ activation=tf.nn.leaky_relu,
+ pn=False,
+ bn=False,
+ gn=False,
+ ln=False,
+ scale=None,
+ bias=None,
+ use_bias=False,
+):
+ """Perform, conv, batch norm, nonlinearity, and max pool"""
stride, no_stride = [1, 1, 2, 2, 1], [1, 1, 1, 1, 1]
_, d, h, w, _ = inp.get_shape()
@@ -773,9 +794,9 @@ def conv_block_3d(
bweight = 0
if not use_stride:
- conv_output = tf.nn.conv3d(inp, cweight, no_stride, 'SAME') + bweight
+ conv_output = tf.nn.conv3d(inp, cweight, no_stride, "SAME") + bweight
else:
- conv_output = tf.nn.conv3d(inp, cweight, stride, 'SAME') + bweight
+ conv_output = tf.nn.conv3d(inp, cweight, stride, "SAME") + bweight
if activation is not None:
conv_output = activation(conv_output, alpha=0.1)
@@ -841,25 +862,27 @@ def conv_cond_concat(x, y):
y_shapes = tf.shape(y)
return tf.concat(
- [x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]]) / 10.], 3)
+ [x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]]) / 10.0], 3
+ )
def attention(
- inp,
- q,
- q_b,
- k,
- k_b,
- v,
- v_b,
- gamma,
- reuse,
- scope,
- stop_at_grad=False,
- seperate=False,
- scale=False,
- train=False,
- dropout=0.0):
+ inp,
+ q,
+ q_b,
+ k,
+ k_b,
+ v,
+ v_b,
+ gamma,
+ reuse,
+ scope,
+ stop_at_grad=False,
+ seperate=False,
+ scale=False,
+ train=False,
+ dropout=0.0,
+):
conv_q = conv_block(
inp,
q,
@@ -871,7 +894,8 @@ def attention(
use_bias=True,
pn=False,
bn=False,
- gn=False)
+ gn=False,
+ )
conv_k = conv_block(
inp,
k,
@@ -883,7 +907,8 @@ def attention(
use_bias=True,
pn=False,
bn=False,
- gn=False)
+ gn=False,
+ )
conv_v = conv_block(
inp,
@@ -894,7 +919,8 @@ def attention(
use_stride=False,
pn=False,
bn=False,
- gn=False)
+ gn=False,
+ )
c_num = float(conv_q.get_shape().as_list()[-1])
s = tf.matmul(hw_flatten(conv_q), hw_flatten(conv_k), transpose_b=True)
@@ -917,48 +943,33 @@ def attention(
def attention_2d(
- inp,
- q,
- q_b,
- k,
- k_b,
- v,
- v_b,
- reuse,
- scope,
- stop_at_grad=False,
- seperate=False,
- scale=False):
+ inp,
+ q,
+ q_b,
+ k,
+ k_b,
+ v,
+ v_b,
+ reuse,
+ scope,
+ stop_at_grad=False,
+ seperate=False,
+ scale=False,
+):
inp_shape = tf.shape(inp)
inp_compact = tf.reshape(
- inp,
- (inp_shape[0] *
- FLAGS.input_objects *
- inp_shape[1],
- inp.shape[3]))
+ inp, (inp_shape[0] * FLAGS.input_objects * inp_shape[1], inp.shape[3])
+ )
f_q = tf.matmul(inp_compact, q) + q_b
f_k = tf.matmul(inp_compact, k) + k_b
f_v = tf.nn.leaky_relu(tf.matmul(inp_compact, v) + v_b)
- f_q = tf.reshape(f_q,
- (inp_shape[0],
- inp_shape[1],
- inp_shape[2],
- tf.shape(f_q)[-1]))
- f_k = tf.reshape(f_k,
- (inp_shape[0],
- inp_shape[1],
- inp_shape[2],
- tf.shape(f_k)[-1]))
- f_v = tf.reshape(
- f_v,
- (inp_shape[0],
- inp_shape[1],
- inp_shape[2],
- inp_shape[3]))
+ f_q = tf.reshape(f_q, (inp_shape[0], inp_shape[1], inp_shape[2], tf.shape(f_q)[-1]))
+ f_k = tf.reshape(f_k, (inp_shape[0], inp_shape[1], inp_shape[2], tf.shape(f_k)[-1]))
+ f_v = tf.reshape(f_v, (inp_shape[0], inp_shape[1], inp_shape[2], inp_shape[3]))
s = tf.matmul(f_k, f_q, transpose_b=True)
- c_num = (32**0.5)
+ c_num = 32**0.5
if scale:
s = s / c_num
@@ -982,24 +993,21 @@ def batch_norm(inp, scale, bias, eps=0.01):
def normalize(inp, activation, reuse, scope):
- if FLAGS.norm == 'batch_norm':
+ if FLAGS.norm == "batch_norm":
return tf_layers.batch_norm(
- inp,
- activation_fn=activation,
- reuse=reuse,
- scope=scope)
- elif FLAGS.norm == 'layer_norm':
+ inp, activation_fn=activation, reuse=reuse, scope=scope
+ )
+ elif FLAGS.norm == "layer_norm":
return tf_layers.layer_norm(
- inp,
- activation_fn=activation,
- reuse=reuse,
- scope=scope)
- elif FLAGS.norm == 'None':
+ inp, activation_fn=activation, reuse=reuse, scope=scope
+ )
+ elif FLAGS.norm == "None":
if activation is not None:
return activation(inp)
else:
return inp
+
# Loss functions
@@ -1009,11 +1017,11 @@ def mse(pred, label):
return tf.reduce_mean(tf.square(pred - label))
-NO_OPS = 'NO_OPS'
+NO_OPS = "NO_OPS"
def _l2normalize(v, eps=1e-12):
- return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
+ return v / (tf.reduce_sum(v**2) ** 0.5 + eps)
def spectral_normed_weight(w, name, lower_bound=False, iteration=1, fc=False):
@@ -1026,11 +1034,12 @@ def spectral_normed_weight(w, name, lower_bound=False, iteration=1, fc=False):
iteration = FLAGS.spec_iter
sigma_new = FLAGS.spec_norm_val
- u = tf.get_variable(name + "_u",
- [1,
- w_shape[-1]],
- initializer=tf.random_normal_initializer(),
- trainable=False)
+ u = tf.get_variable(
+ name + "_u",
+ [1, w_shape[-1]],
+ initializer=tf.random_normal_initializer(),
+ trainable=False,
+ )
u_hat = u
v_hat = None
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..2f81b07
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,22 @@
+.PHONY: install lint clean
+
+install:
+ @echo "creating virtual environment..."
+ python3 -m venv venv
+ @echo "run: source venv/bin/activate"
+ venv/bin/pip3 install -r scripts/requirements.txt
+
+lint:
+ venv/bin/python3 scripts/auto_fix.py
+
+clean:
+ @echo "🧹 cleaning build artifacts and cache..."
+ find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
+ find . -type f -name "*.pyc" -delete 2>/dev/null || true
+ find . -type f -name "*.pyo" -delete 2>/dev/null || true
+ find . -type f -name "*.pyd" -delete 2>/dev/null || true
+ find . -type f -name ".coverage" -delete 2>/dev/null || true
+ find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
+ find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
+ find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
+ @echo "✨ cleanup complete!"
diff --git a/README.md b/README.md
index e2ddb0a..705999b 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,27 @@
-## resources and experiments on autonomous agents
-
+## the AI toolkit
+
-* **[⬛ ai && ml tl; dr](deep_learning)**
+#### research
+
+* **[⬛ machine learning history](deep_learning)**
* **[⬛ large language models](llms)**
-* **[⬛ agents on blockchains](crypto_agents)**
-* **[⬛ on quantum computing](EBMs)**
- - my adaptation of openai's implicit generation and generalization in energy based models
+
+
+
+#### experiments
+
+* **[⬛ agents on blockchains](crypto_agents)**
+* **[⬛ on quantum computing](EBMs)** (my adaptation of openAI's implicit generation and generalization in energy based
+models)
---
-### cool resources
+### cool discussions
+* **[vub's response to AI 2027 and his take on defense (2025)](https://vitalik.eth.limo/general/2025/07/10/2027.html)**
* **[mr. vp jd vance at the ai action summit in paris (2025)](https://www.youtube.com/watch?v=MnKsxnP2IVk)**
diff --git a/crypto_agents/README.md b/crypto_agents/README.md
index 588b854..dc15702 100644
--- a/crypto_agents/README.md
+++ b/crypto_agents/README.md
@@ -1,8 +1,9 @@
-## crypto agents
+## agents on blockchains
-* **[basic strategy workflow](strategy_workflow)**
+* **[basic strategy workflow (2023)](strategy_workflow)**
+* **[trading on gmx (2023)](trading_on_gmx.md)**
@@ -12,7 +13,6 @@
-
##### projects
* **[ritual.net](https://ritual.net/)**
@@ -27,26 +27,37 @@
##### readings
-* **[the internet's notary public: why verifiability matters, by axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)**
-* **[cryptos role in the ai revolution, by pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)**
-* **[the promise and challenges of crypto + ai applications, by vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)**
-* **[on training defi agents with markov chains, by bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)**
+* **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)**
+* **[the internet's notary public: why verifiability matters, by
+axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)**
+* **[cryptos role in the ai revolution, by
+pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)**
+* **[the promise and challenges of crypto + ai applications, by
+vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)**
+* **[on training defi agents with markov chains, by
+bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)**
##### metas
* **eth denver 2025 ai + agentic metas**:
- * **[a new dawn for decentralized and the convergence of ai x blockchain, by n. ameline](https://www.youtube.com/watch?v=HQuGtN9zidQ)**
- * **[upgrading agent infra with onchain perpetualaAgents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)**
- * **[ai meets blockchain: payment, multi-Agent & distributed inference, by o. jaros](https://www.youtube.com/watch?v=aPTosw4hrY0)**
+ * **[a new dawn for decentralized and the convergence of ai x blockchain, by n.
+* **[upgrading agent infra with onchain perpetualaAgents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)**
+ * **[ai meets blockchain: payment, multi-Agent & distributed inference, by o.
* **[from ai agents to agentic economies, by r. bodkin](https://www.youtube.com/watch?v=Q7eaYJ9aPpI)**
* **[building ai agents and agent economies, by d. minarsch](https://www.youtube.com/watch?v=tDVK2Q5RY0c)**
* **[building ai agents on top of hedera, by j. hall](https://www.youtube.com/watch?v=h8D6vi2m8LQ)**
- * **[identity, privacy, and security in the new age of ai agents, by m. csernai](https://www.youtube.com/watch?v=vMcaot04RQo)**
+ * **[identity, privacy, and security in the new age of ai agents, by m.
* **[verifiable ai agents and world superintelligence, by k. wong](https://www.youtube.com/watch?v=ngkp7HTj_4A)**
* **[how web3 can compete in ai, by g. narula](https://www.youtube.com/watch?v=oLLM1I-3fDU)**
* **[shade agents, by m. lockyer](https://www.youtube.com/watch?v=PEfJnCtrbMU)**
* **[how ai will enable the next generatin of intent, by i. yang](https://www.youtube.com/watch?v=fbc3DpI6jiA)**
* **[2025 is the year of agents, by i. polosukhin](https://www.youtube.com/watch?v=jPyzVNcQMKw)**
+* **[our awesome decentralized AI](https://github.com/shadowy-forest/awesome-decentralized-ai)**
+
+
+##### books
+
+* **[advances in financial machine learning](../books/advances_in_financial_machine_learning.pdf)**
diff --git a/crypto_agents/strategy_workflow/README.md b/crypto_agents/strategy_workflow/README.md
index e08d407..ef9ed9f 100644
--- a/crypto_agents/strategy_workflow/README.md
+++ b/crypto_agents/strategy_workflow/README.md
@@ -3,7 +3,8 @@
-
+
diff --git a/crypto_agents/strategy_workflow/defi_glossary.md b/crypto_agents/strategy_workflow/defi_glossary.md
index 147fbc5..99e53a7 100644
--- a/crypto_agents/strategy_workflow/defi_glossary.md
+++ b/crypto_agents/strategy_workflow/defi_glossary.md
@@ -5,59 +5,81 @@
### A
-- Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage of their price discrepancies.
-- Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of clients.
+- Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage
+of their price discrepancies.
+- Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of
+clients.
### B
-- Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target transaction.
-- Blocks: a block contains transaction data and the hash of the previous block ensuring immutability in the blockchain network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the updates to the blockchain state.
+- Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target
+transaction.
+- Blocks: a block contains transaction data and the hash of the previous block ensuring immutability in the blockchain
+network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the
+updates to the blockchain state.
- Block time: the time interval between blocks being added to the blockchain.
-- Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to include the transaction to the network. This request is public (anyone can listen to it).
-- Builders: actors that take bundles (of pendent transactions from the mempool) and create a final block to send to (multiple) relays (setting themselves afeeRecipient to receive the block’s MEV).
-- Bundles: one or more transactions that are grouped together and executed in the order they are provided. In addition to the searcher's transaction(s), a bundle can also contain other users' pending transactions from the mempool. Bundles can target specific blocks for inclusion as well.
+- Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to include the transaction to
+the network. This request is public (anyone can listen to it).
+- Builders: actors that take bundles (of pendent transactions from the mempool) and create a final block to send to
+(multiple) relays (setting themselves afeeRecipient to receive the block’s MEV).
+- Bundles: one or more transactions that are grouped together and executed in the order they are provided.
### C
-- Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size that they are willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until the desired size is reached.
-- Contract address: the address hosting some source code deployed on the Ethereum blockchain, which is executed by a triggering transaction.
-- Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another trader's method.
+- Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size that they are
+willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until
+the desired size is reached.
+- Contract address: the address hosting some source code deployed on the Ethereum blockchain, which is executed by a
+triggering transaction.
+- Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another
+trader's method.
### D
- Derivatives: financial contracts that derive their values from underlying assets.
-- Dollar-cost-averaging (DCA) strategy: a one-stop automated trading, based on time intervals, and reducing the influence of market volatility. Parameters for DCA can be: currency, fixed/maximum investment, and amount, investment frequency.
+- Dollar-cost-averaging (DCA) strategy: a one-stop automated trading, based on time intervals, and reducing the
+influence of market volatility. Parameters for DCA can be: currency, fixed/maximum investment, and amount, investment
+frequency.
### E
-- Epoch: in the context of Ethereum's block production, in each slot (every 12 seconds), a validator is randomly chosen to propose the block in that slot. An epoch contains 32 slots.
-- Externally owned account (EOA): an account that is a combination of public address and private key, and that can be used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal address derived from the last 20 bytes of the public key of the account (with 0x appended in front).
+- Epoch: in the context of Ethereum's block production, in each slot (every 12 seconds), a validator is randomly chosen
+to propose the block in that slot. An epoch contains 32 slots.
+- Externally owned account (EOA): an account that is a combination of public address and private key, and that can be
+used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal address
+derived from the last 20 bytes of the public key of the account (with 0x appended in front).
### F
-- Frontrunning: the process by which an adversary observes transactions on the network layer and acts on this information to obtain profit.
+- Frontrunning: the process by which an adversary observes transactions on the network layer and acts on this
+information to obtain profit.
- Fully diluted valuations (FDV): the total number of tokens multiplied by the current price of a single token.
-- Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their price changes.
-- Future grid trading bots: bots that automate futures trading activities based on grid trading strategies (a set of orders is placed both above and below a specific reference market price for the asset).
+- Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their
+price changes.
+- Future grid trading bots: bots that automate futures trading activities based on grid trading strategies (a set of
+orders is placed both above and below a specific reference market price for the asset).
### G
-- Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of execution) to have their transaction processed.
-- Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency. A gwei or gigawei is defined as 1,000,000,000 wei, the smallest base unit of Ether. Conversely, 1 ETH represents 1 billion gwei.
-- Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching goal of buying low and selling high.
+- Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of execution) to have
+their transaction processed.
+- Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency.
+- Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of
+orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching
+goal of buying low and selling high.
@@ -75,10 +97,12 @@
### L
-- Limit orders: when one longs or shorts a contract, several execution options can be placed (usually with a fee difference). Limit orders that are set at a specific price to be traded, and there is no guarantee that the trade will be executed (see market orders and stop-loss orders).
-- Liquidity pools: a collection of crypto assets that can be used for decentralized trading. They are essential for automated market makers (AMM), borrow-lend protocols, yield farming, synthetic assets, on-chain insurance, blockchain gaming, etc.
+- Limit orders: when one longs or shorts a contract, several execution options can be placed (usually with a fee
+difference). Limit orders that are set at a specific price to be traded, and there is no guarantee that the trade will
+be executed (see market orders and stop-loss orders).
+- Liquidity pools: a collection of crypto assets that can be used for decentralized trading.
- Liquidation threshold: the percentage at which a collateral value is counted towards the borrowing capacity.
-- Liquidation: when the value of a borrowed asset exceeds the collateral. Anyone can liquidate the collateral and collect the liquidation fee for themselves.
+- Liquidation: when the value of a borrowed asset exceeds the collateral.
- Long: traders maintain long positions, which means that they expect the price of a coin to rise in the future.
@@ -86,23 +110,32 @@
### M
- Fully diluted market capitalization: the total token supply, multiplied by the price of a single token.
-- Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the price of a single token.
+- Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the
+price of a single token.
- Margin trading: buying or sell assets with leverage.
- Marginal seller: a type of seller who is willing first to leave the market if the prices are lower.
- Market orders: Market orders are executed immediately at the asset's market price (see limit orders).
-- Mean reversion strategy: a trading range (or mean reversion) strategy is based on the concept that an asset's high and low prices are a temporary effect that reverts to their mean value (average value).
-- Mempool: a cryptocurrency node’s mechanism for storing information on unconfirmed transactions.
-- Merkle tree: a type of binary tree, composed of: 1) a set of notes with a large number of leaf nodes at the bottom, containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a single root node, also formed from the hash of its two children, representing the top of the tree.
-- Minting: the process of validating information, creating a new block, and recording that information into the blockchain.
+- Mean reversion strategy: a trading range (or mean reversion) strategy is based on the concept that an asset's high and
+low prices are a temporary effect that reverts to their mean value (average value).
+- Mempool: a cryptocurrency node’s mechanism for storing information on unconfirmed transactions.
+- Merkle tree: a type of binary tree, composed of: 1) a set of notes with a large number of leaf nodes at the bottom,
+containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a
+single root node, also formed from the hash of its two children, representing the top of the tree.
+- Minting: the process of validating information, creating a new block, and recording that information into the
+blockchain.
### P
-- Perpetual contract: a contract without an expiration date, where interest rates can be calculated by methods such as Time-Weighted-Average-Price (TWAP).
-- Priority gas auctions: bots compete against each other by binding up transaction fees (gas) to extract revenue from arbitrage opportunities, driving up user fees.
-- Private key: a secret number enabling a blockchain user to prove ownership on an account or contract, via a digital signature.
-- Publick key: a number generated by a one-way (hash) function from the private key, used to verify a digital signature made with the matching private key.
+- Perpetual contract: a contract without an expiration date, where interest rates can be calculated by methods such as
+Time-Weighted-Average-Price (TWAP).
+- Priority gas auctions: bots compete against each other by binding up transaction fees (gas) to extract revenue from
+arbitrage opportunities, driving up user fees.
+- Private key: a secret number enabling a blockchain user to prove ownership on an account or contract, via a digital
+signature.
+- Publick key: a number generated by a one-way (hash) function from the private key, used to verify a digital signature
+made with the matching private key.
- Provider: an entity that provides an abstraction for a connection to the blockchain network.
- POFPs: private order flow protocols.
@@ -110,8 +143,9 @@
### O
-- Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state of the blockchain.
-- Open interest: total number of futures contracts held by market participants at the end of the trading day. Used as an indicator to determine market sentiment and the strength behind price trends.
+- Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state
+of the blockchain.
+- Open interest: total number of futures contracts held by market participants at the end of the trading day.
@@ -123,37 +157,48 @@
### S
-- Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen validator has time to propose a block.
-- Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties. They are reliant upon code (the functions) and data (the state), and they can trigger specific actions, such as transferring tokens from A to B.
-- Sandwich attack: when slippage value is not set, this attack can happen by an actor bumping the price of an asset to an unfavorable level, executing the trade, and then returning the asset to the original price.
+- Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen
+validator has time to propose a block.
+- Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties.
+- Sandwich attack: when slippage value is not set, this attack can happen by an actor bumping the price of an asset to
+an unfavorable level, executing the trade, and then returning the asset to the original price.
- Slippage: delta in pricing between the time of order and when the order is executed.
- Short: traders maintain short positions, which means they expect the price of a coin to drop in the future.
-- Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason. This situation prompts short sellers to scramble to buy the stock to cover their positions and cap their mounting losses.
+- Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason.
- Spot trading: buy or selling assets for immediate delivery.
-- Statistical trading: is the class of strategies that aim to generate profitable situations, stemming from pricing inefficiencies among financial markets. Statistical arbitrage is a strategy to obtain profit by applying past statistics.
-- Stop-loss orders: this type of order execution places a market/limit order to close a position to restrict an investor's loss on a crypto asset.
+- Statistical trading: is the class of strategies that aim to generate profitable situations, stemming from pricing
+inefficiencies among financial markets. Statistical arbitrage is a strategy to obtain profit by applying past
+statistics.
+- Stop-loss orders: this type of order execution places a market/limit order to close a position to restrict an
+investor's loss on a crypto asset.
### T
-- otal value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or derivatives protocols.
+- otal value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or
+derivatives protocols.
- Тrading volume: the total amount of traded cryptocurrency (equivalent to US dollars) during a given timeframe.
-- Transaction: on EVM-based blockchains, there the two types of transactions are normal transactions and contract interactions.
+- Transaction: on EVM-based blockchains, there the two types of transactions are normal transactions and contract
+interactions.
- Transaction hash: a unique 66-character identifier generated with each new transaction.
-- Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block, allowing attacks that benefit from certain ordering.
-- Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using evenly divided time slots between a start and end time.
+- Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block,
+allowing attacks that benefit from certain ordering.
+- Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined
+smaller chunks of the order to the market, using evenly divided time slots between a start and end time.
### V
-- Validation: a mathematical proof that the state change in the blockchain is consistent. To be included into a block in the blockchain, a list of transactions needs to be validated.
+- Validation: a mathematical proof that the state change in the blockchain is consistent.
- VTRPs: validator transaction Reordering protocols.
-- Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using historical volume profiles.
+- Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller
+chunks of the order to the market, using historical volume profiles.
### W
-- Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become powerful enough to manipulate the valuation.
+- Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become
+powerful enough to manipulate the valuation.
diff --git a/crypto_agents/strategy_workflow/optimization.md b/crypto_agents/strategy_workflow/optimization.md
index cbb5e96..d5ed6ce 100644
--- a/crypto_agents/strategy_workflow/optimization.md
+++ b/crypto_agents/strategy_workflow/optimization.md
@@ -2,5 +2,6 @@
-* perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients (using the simulator and a set of historical data)
+* perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients
+(using the simulator and a set of historical data)
* overfitting to historical data is a big risk (be careful with validation and test sets).
diff --git a/crypto_agents/strategy_workflow/paper_trading.md b/crypto_agents/strategy_workflow/paper_trading.md
index 4764e97..4ae764f 100644
--- a/crypto_agents/strategy_workflow/paper_trading.md
+++ b/crypto_agents/strategy_workflow/paper_trading.md
@@ -2,4 +2,5 @@
-* before the strategy goes live, simulation is done on new market data, in real-time (paper trading), which prevents overfitting
+* before the strategy goes live, simulation is done on new market data, in real-time (paper trading), which prevents
+overfitting
diff --git a/crypto_agents/strategy_workflow/policy.md b/crypto_agents/strategy_workflow/policy.md
index d9a1d07..6a8fc31 100644
--- a/crypto_agents/strategy_workflow/policy.md
+++ b/crypto_agents/strategy_workflow/policy.md
@@ -2,4 +2,5 @@
-* come with a rule-based policy that determines what actions to take based on the current state of the market and the outpus of supervised models.
+* come with a rule-based policy that determines what actions to take based on the current state of the market and the
+outputs of supervised models.
diff --git a/crypto_agents/strategy_workflow/strategy_metrics.md b/crypto_agents/strategy_workflow/strategy_metrics.md
index d7b86e6..39184b7 100644
--- a/crypto_agents/strategy_workflow/strategy_metrics.md
+++ b/crypto_agents/strategy_workflow/strategy_metrics.md
@@ -2,8 +2,12 @@
-* **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period, minus trading fees
+* **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period,
+minus trading fees
* **alpha nad beta**
-* **shape ratio:** the excess return per unit of risk you are taking (return on capital over the standard deviation adjusted for risk; the higher the better).
-* **maximum drawdown:** maximum difference between a local maximum and a subsequent local minimum as an another measure of risk.
-* **value at risk (var):** how much capital you may lose over a given time frame with some probability, assumong normal market conditions.
+* **shape ratio:** the excess return per unit of risk you are taking (return on capital over the standard deviation
+adjusted for risk; the higher the better).
+* **maximum drawdown:** maximum difference between a local maximum and a subsequent local minimum as an another measure
+of risk.
+* **value at risk (var):** how much capital you may lose over a given time frame with some probability, assumong normal
+market conditions.
diff --git a/crypto_agents/strategy_workflow/supervised_learning.md b/crypto_agents/strategy_workflow/supervised_learning.md
index e9cc446..9cee0c0 100644
--- a/crypto_agents/strategy_workflow/supervised_learning.md
+++ b/crypto_agents/strategy_workflow/supervised_learning.md
@@ -2,4 +2,5 @@
-* train one or more supervised learning models to predict quantities of interest that are necessary for the strategy work, for example, price prediction, quantity prediction, etc.
+* train one or more supervised learning models to predict quantities of interest that are necessary for the strategy
+work, for example, price prediction, quantity prediction, etc.
diff --git a/crypto_agents/trading_on_gmx.md b/crypto_agents/trading_on_gmx.md
index d79d98e..0fcb620 100644
--- a/crypto_agents/trading_on_gmx.md
+++ b/crypto_agents/trading_on_gmx.md
@@ -13,8 +13,10 @@
-
-
+
+
@@ -25,10 +27,12 @@
* the order book is made of two sides, asks (sell, offers) and bids (buy).
-* the best ask (the lowest price someone is willing to sell ) > the best bid (the highest price someone is willing to buy).
+* the best ask (the lowest price someone is willing to sell ) > the best bid (the highest price someone is willing to
+buy).
* the difference between the best ask and the best bid is called spread.
* **market order**: best price possible, right now. it takes liquidity from the market and usually has higher fees.
-* **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the match.
+* **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the
+match.
* **stop orders**: allow you to set a maximum price for your market orders.
diff --git a/deep_learning/README.md b/deep_learning/README.md
index 5cefb7f..8369a09 100644
--- a/deep_learning/README.md
+++ b/deep_learning/README.md
@@ -1,4 +1,4 @@
-## ai agents
+## some machine learning history
@@ -13,8 +13,6 @@
-* **[cursor ai editor](https://www.cursor.com/)**
-* **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)**
* **[google's jax (composable transformations of numpy programs)](https://github.com/google/jax)**
* **[machine learning engineering open book](https://github.com/stas00/ml-engineering)**
-* **[advances in financial machine learning](books/advances_in_financial_machine_learning.pdf)**
+
diff --git a/deep_learning/deep_learning.md b/deep_learning/deep_learning.md
index af5cbd2..ed39301 100644
--- a/deep_learning/deep_learning.md
+++ b/deep_learning/deep_learning.md
@@ -1,4 +1,4 @@
-## deep learning
+## deep learning
@@ -10,9 +10,11 @@
* **[2013: atari with deep reinforcement learning](https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial)**
* **[2014: seq2seq](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt)**
-* **[2014: adam optmizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)**
+* **[2014: adam
+optmizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)**
* **[2015: gans](https://www.tensorflow.org/tutorials/generative/dcgan)**
-* **[2015: resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)**
+* **[2015:
+resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)**
* **[2017: transformers](https://github.com/huggingface/transformers)**
* **[2018: bert](https://arxiv.org/abs/1810.04805)**
@@ -24,30 +26,38 @@
-* a map consists of a set of states, a set of actions, a transition function that describes the probability of moving rom one state to another after taking an action, and a reward function that assigns a numerical reward to each state-action pair
+* a map consists of a set of states, a set of actions, a transition function that describes the probability of moving
+rom one state to another after taking an action, and a reward function that assigns a numerical reward to each
+state-action pair
* the goal of a map is to maximize its expected cumulative reward over a sequence of actions, called a policy.
-* a policy is a function that maps each state to a probability distribution over actions. The optimal policy is the one that maximizes the expected cumulative rewards.
+* a policy is a function that maps each state to a probability distribution over actions.
-* the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as the optimal control of incompletely-known Markov decision processes.
+* the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as
+the optimal control of incompletely-known Markov decision processes.
-* as opposed to supervised learning, an agent must be able to learn from its own experience. and as oppose to unsupervised learning because, reinforcement learning is trying to maximize a reward signal instead of trying to find hidden structure.
+* as opposed to supervised learning, an agent must be able to learn from its own experience.
-* the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain a reliable estimate of its expected reward.
+* the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in
+order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain
+a reliable estimate of its expected reward.
-* beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a policy, a reward signal, a value function, and, optionally, a model of the environment.
+* beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a
+policy, a reward signal, a value function, and, optionally, a model of the environment.
-* traditional reinforcement learning problems can be formulated as a markov decision process (MDP):
+* traditional reinforcement learning problems can be formulated as a markov decision process (MDP):
* we have an agent acting in an environment
- * each step *t* the agent receives as the input the current state S_t, takes an action A_t, and receives a reward R_{t+1} and the next state S_{t+1}
+* each step *t* the agent receives as the input the current state S_t, takes an action A_t, and receives a reward
+R_{t+1} and the next state S_{t+1}
* the agent choose the action based on some policy pi: A_t = pi(S_t)
- * it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon
+* it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon
-
+
@@ -55,7 +65,7 @@
-* agent is the trading agent (e.g. the human trader who opens the gui of an exchange and makes trading decision based on the current state of the exchange and their account)
+* agent is the trading agent (e.g.
@@ -65,7 +75,8 @@
* the exchange and other agents are the environment, and they are not something we can control
* by putting other agents together into some big complex environment, we lose the ability to explicitly model them
-* if we try to reverse-engineer the algorithms and strategies that other traders are running, put us into a multi-agent reinforcement learning (MARL) problem setting
+* if we try to reverse-engineer the algorithms and strategies that other traders are running, put us into a multi-agent
+reinforcement learning (MARL) problem setting
@@ -73,11 +84,12 @@
-* in the case of trading on an exchange, we don't observe the complete state of the environment (e.g. other agents), so we are dealing with a partially observable markov decision process (pomdp).
+* in the case of trading on an exchange, we don't observe the complete state of the environment (e.g.
* what the agents observe is not the actual state S_t of the environment, but some derivation of that.
* we can call the observation X_t, which is calculated using some function of the full state X_t ~ O(S_t)
* the observation at each timestep t is simply the history of all exchange events received up to time t.
-* this event history can be used to build up the current exchange state, however, in order for our agent to make decisions, extra info such as account balance and open limit orders need to be included.
+* this event history can be used to build up the current exchange state, however, in order for our agent to make
+decisions, extra info such as account balance and open limit orders need to be included.
@@ -85,8 +97,9 @@
-* hft techniques: decisions are based almost entirely on market microstructure signals. decisions are made on nanoseconds timescales and trading strategies use dedicated connections to exchanges and extremly fast but simple algorithms running fpga hardware.
-* neural networks are slow, they can't make predictions on nanoseconds time scales, so they can't compete with the speed of hft algorithms.
+* hft techniques: decisions are based almost entirely on market microstructure signals.
+* neural networks are slow, they can't make predictions on nanoseconds time scales, so they can't compete with the speed
+of hft algorithms.
* guess: the optimal time scale is between a few milliseconds and a few minutes.
* can deep rl algorithms pick up hidden patterns?
@@ -96,9 +109,11 @@
-* the simplest approach has 3 actions: buy, hold, and sell. this works but limits us to placing market orders and to invest a deterministic amount of money at each step.
-* in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model, putting us into a continuous action space.
-* in the next level, we would introduce limit orders, and the agent needs to decide the level (price) and wuantity of the order, and be able to cancel orders that have not been yet matched.
+* the simplest approach has 3 actions: buy, hold, and sell.
+* in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model,
+putting us into a continuous action space.
+* in the next level, we would introduce limit orders, and the agent needs to decide the level (price) and wuantity of
+the order, and be able to cancel orders that have not been yet matched.
@@ -106,18 +121,21 @@
-* there are several possible reward functions, an obvious would realized PnL (profit and loss). the agent receives a reward whenever it closes a position.
+* there are several possible reward functions, an obvious would realized PnL (profit and loss).
* the net profit is either negative or positive, and this is the reward signal.
-* as the agent maximize the total cumulative reward, it learns to trade profitably. the reward function leads to the optimal policy in the limit.
-* however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent feedback.
-* an alternative is unrealized pnl, which the net profit the agent would get if it were to close all of its positions immediately.
-* because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals. however the direct feedback may bias the agent towards short-term actions.
+* as the agent maximize the total cumulative reward, it learns to trade profitably.
+* however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent
+feedback.
+* an alternative is unrealized pnl, which the net profit the agent would get if it were to close all of its positions
+immediately.
+* because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals.
* both naively optimize for profit, but a trader may want to minimize risk (lower volatility)
* using the sharpe ration is one simple way to take risk into account. other way is maximum drawdown.
-
+
@@ -134,11 +152,14 @@
-* we need separate backtesting and parameter optimization steps because it was difficult for our strategies to take into account environmental factors: order book liquidity, fee structures, latencies.
-* getting around environmental limitations is part of the opimization process. if we simulate the latency in the reinforcement learning environment, and this results in the agent making a mistake, the agent will get a negative rewards, forcing it to learn to work around the latencies.
-* by learning a model of the environment and performing rollouts using techniques like a monte carlo tree search (mcts), we could take into account potential reactions of the market (other agents)
+* we need separate backtesting and parameter optimization steps because it was difficult for our strategies to take into
+account environmental factors: order book liquidity, fee structures, latencies.
+* getting around environmental limitations is part of the opimization process.
+* by learning a model of the environment and performing rollouts using techniques like a monte carlo tree search (mcts),
+we could take into account potential reactions of the market (other agents)
* by being smart about the data we collect from the live environment, we can continously improve our model
-* do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting information that we can use to improve the model of our environment and other agents?
+* do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting
+information that we can use to improve the model of our environment and other agents?
@@ -147,7 +168,9 @@
* some strategy may work better in a bearish environment but lose money in a bullish environment.
-* because rl agents are learning powerful policies parameterized by NN, they can alos learn to adapt to market conditions by seeing them in historical data, given that they are trained over long time horizon and have sufficient memory.
+* because rl agents are learning powerful policies parameterized by NN, they can alos learn to adapt to market
+conditions by seeing them in historical data, given that they are trained over long time horizon and have sufficient
+memory.
@@ -156,9 +179,13 @@
* the trading environment is a multiplayer game with thousands of agents acting simultaneously
-* understanding how to build models of other agents is only one possible we can, we can choose perfom actions in a live environment with the goal of maximizing the information grain with respect to kind policies the other agents may be following
+* understanding how to build models of other agents is only one possible we can, we can choose perfom actions in a live
+environment with the goal of maximizing the information grain with respect to kind policies the other agents may be
+following
* trading agents receive sparse rewards from the market. naively applying reward-hungry rl algorithms will fail.
* this opens up the possibility for new algorithms and techniques, that can efficiently deal with sparse rewards.
-* many of today's standard algorithms, such as dqn or a3c, use a very naive approach exploration - basically adding random noise to the policy. however, in the trading case, most states in the environment are bad, and there are only a few good ones. a naive random approach to exploration will almost never stumble upon good state-actions paris.
-* the trading environment is inherently nonstationary. market conditions change and other agent join, leave, and constantly change their strategies.
+* many of today's standard algorithms, such as dqn or a3c, use a very naive approach exploration - basically adding
+random noise to the policy. however, in the trading case, most states in the environment are bad, and there are only a
+few good ones. a naive random approach to exploration will almost never stumble upon good state-actions paris.
+* the trading environment is inherently nonstationary.
* can we train an agent that can transit from bear to bull and then back to bear, without needing to be re-trained?
diff --git a/deep_learning/reinforcement_learning.md b/deep_learning/reinforcement_learning.md
index b2d8472..eb9d3b2 100644
--- a/deep_learning/reinforcement_learning.md
+++ b/deep_learning/reinforcement_learning.md
@@ -6,8 +6,10 @@
-* reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward signal
-* an autonomous agent is a software program or system that can operate independently and make decisions on its own, without direct intervention from a human
+* reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward
+signal
+* an autonomous agent is a software program or system that can operate independently and make decisions on its own,
+without direct intervention from a human
@@ -17,10 +19,13 @@
-* we formalize the problem of reinforcement using ideas from dynamical system theory, as the optimal control of incompletely-known Markov decision processes.
-* a learning agent must be able to sense the state of its environment to some extent and must be able to take actions that affect the state.
+* we formalize the problem of reinforcement using ideas from dynamical system theory, as the optimal control of
+incompletely-known Markov decision processes.
+* a learning agent must be able to sense the state of its environment to some extent and must be able to take actions
+that affect the state.
* markov decision processes are intented to include just these three aspects, sensation, action, and goal.
-* the agent has to exploit what it has already experienced in order to obtain reward, but it has also to explore in order to make better action selections in the future.
+* the agent has to exploit what it has already experienced in order to obtain reward, but it has also to explore in
+order to make better action selections in the future.
* on a stochastic tasks, each action must be tried many times to gain a reliable estimate of its expected reward.
@@ -31,12 +36,17 @@
-* beyond the agent and the environment, 4 more elements belong to a reinforcement learning system: a policy, a reward signal, a value funtion, and a model of the environmnet.
-* a policy defines the learning agent's way of behacing at a given time. It's a mapping from perceiv ed states of the environment to actions to be taken when in those states. in general, policies may be stochastics (specifying probabilities for each action).
-* a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total reward over the run.
-* a value function specifies what is good in the long run, the valye of a state in the total amount of reward an agent can expect to accumulate over the future, starting from that state
+* beyond the agent and the environment, 4 more elements belong to a reinforcement learning system: a policy, a reward
+signal, a value funtion, and a model of the environmnet.
+* a policy defines the learning agent's way of behacing at a given time.
+* a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the
+reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total
+reward over the run.
+* a value function specifies what is good in the long run, the valye of a state in the total amount of reward an agent
+can expect to accumulate over the future, starting from that state
* a model of the environment.
-* the most important feature distinguishing reinforcement learning from other types of learning is that it uses training information that evaluates the actions taken rather than instructs by giving correct actions.
+* the most important feature distinguishing reinforcement learning from other types of learning is that it uses training
+information that evaluates the actions taken rather than instructs by giving correct actions.
@@ -47,7 +57,8 @@
* the problem involves evaluating feedbacks and choosing different actions in different situations.
-* mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards, but also subsequent situations.
+* mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards,
+but also subsequent situations.
* mdps involve delayed reward and the need to trade off immediate and delayed reward.
@@ -57,34 +68,42 @@
* mdps are meant to be a straightfoward framing of the problem of learning from interaction to achieve a goal.
* the learner and the decision makers is called the agent.
* the thing it interacts with, comprimising everything outside the agent, is called the environment.
-* the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice of actions.
+* the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice
+of actions.
-
+
* the agent and the environment interact at each of a sequence of discrete steps, t = 0, 1, 2, 3...
* at each time step t, the agent receives some representation of the environments state St
* on that basis, the agent selects an action At
-* one step later, in part of a consequence of its action, the agent receives a numerical rewards and finds itself in a new state.
+* one step later, in part of a consequence of its action, the agent receives a numerical rewards and finds itself in a
+new state.
* the mdp and the agent together give rise to a sequence (trajectory)
-* in a finite mdp, the set of states, actions, and rewards all have a finite number of elements. in this case, the random variables R and S have well defined discrete probability distributions dependent only on the proceding state and action.
+* in a finite mdp, the set of states, actions, and rewards all have a finite number of elements.
* in a markov decision process, the probabilities given by p completely characterize the environment's dynamics.
-* the state must include information about all aspects of the past agent-environment interaction that make a differnce for the future.
-* anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its environment.
+* the state must include information about all aspects of the past agent-environment interaction that make a differnce
+for the future.
+* anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its
+environment.
##### goals and rewards
-* each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to a sample from a standard distribution of starting states.
-* almost all reinforcement learning algorithms involve estimating value functions—functions of states (or of state–action pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a given action in a given state).
-* the Bellman equation averages over all the possibilities, weighting each by its probability of occurring. tt states that the value of the start state must equal the
+* each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to
+a sample from a standard distribution of starting states.
+* almost all reinforcement learning algorithms involve estimating value functions—functions of states (or of
+state–action pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a
+given action in a given state).
+* the Bellman equation averages over all the possibilities, weighting each by its probability of occurring.
(discounted) value of the expected next state, plus the reward expected along the way.
-* solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run.
+* solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run.
@@ -92,20 +111,28 @@
### dynamic programming
-* collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as a mdp.
-* a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state and action spaces and then apply finite-state DP methods.
+* collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as a
+mdp.
+* a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state
+and action spaces and then apply finite-state DP methods.
* the reason for computing the value function for a policy is to help find better policies.
-* asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other states happen to be available. the values of some states may be updated several times before the values of others ar
-* policy evaluation refers to the (typi- cally) iterative computation of the value function for a given policy.
+* asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps
+of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other
+states happen to be available. the values of some states may be updated several times before the values of others ar
+* policy evaluation refers to the (typi- cally) iterative computation of the value function for a given policy.
* policy improvement refers to the computation of an improved policy given the value function for that policy.
##### generalized policy interaction
-* policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with the current policy (policy evaluation), and the other making the policy greedy with respect to the current value function (policy improvement).
-* generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement processes interact, independent of the granularity and other details of the two processes.
-* DP is sometimes thought to be of limited applicability because of the curse of dimen- sionality, the fact that the number of states often grows exponentially with the number of state variables
+* policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with
+the current policy (policy evaluation), and the other making the policy greedy with respect to the current value
+function (policy improvement).
+* generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement
+processes interact, independent of the granularity and other details of the two processes.
+* DP is sometimes thought to be of limited applicability because of the curse of dimen- sionality, the fact that the
+number of states often grows exponentially with the number of state variables
@@ -116,4 +143,4 @@
* **[gymnasium api](https://gymnasium.farama.org/)**
-* **[reinforcement learning with unsupervised auxiliary tasks, by jaderberg et al.](https://arxiv.org/abs/1611.05397)**
\ No newline at end of file
+* **[reinforcement learning with unsupervised auxiliary tasks, by jaderberg et al.](https://arxiv.org/abs/1611.05397)**
diff --git a/llms/README.md b/llms/README.md
index 9d49c2b..7654b9f 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -2,9 +2,9 @@
-* **[opeanai](opeanai)**
+* **[google's gemini](gemini)**
+* **[openAI](openAI)**
* **[claude](claude)**
-* **[gemini](gemini)**
* **[deepseek](deepseek)**
@@ -17,7 +17,7 @@
#### articles
-* **[people cannot distinguish gpt-4 from a human in a turing test, by c. jones et al (2024)](https://arxiv.org/pdf/2405.08007)**
+* **[people cannot distinguish gpt-4 from a human in a turing test, by c.
diff --git a/llms/claude/README.md b/llms/claude/README.md
index 60e2e32..2e69c19 100644
--- a/llms/claude/README.md
+++ b/llms/claude/README.md
@@ -2,6 +2,12 @@
+
+
+
+
+---
+
### cool resources
diff --git a/llms/deepseek/README.md b/llms/deepseek/README.md
index 4d1d59c..8f73eac 100644
--- a/llms/deepseek/README.md
+++ b/llms/deepseek/README.md
@@ -3,6 +3,8 @@
-
-
+
+
diff --git a/llms/gemini/README.md b/llms/gemini/README.md
index f98041f..54d5934 100644
--- a/llms/gemini/README.md
+++ b/llms/gemini/README.md
@@ -1 +1 @@
-## gemini
\ No newline at end of file
+## gemini
diff --git a/llms/openai/README.md b/llms/openai/README.md
index 8748b82..af2d758 100644
--- a/llms/openai/README.md
+++ b/llms/openai/README.md
@@ -1,16 +1,25 @@
-## openai
+## openAI
+
+
+
+
+---
+
### cool resources
-* **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode) (and [here](https://marketplace.visualstudio.com/items?itemName=timkmecl.chatgpt))**
-* **[scispace extension (paper explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)**
+* **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode)**
+* **[scispace extension (paper
+explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)**
* **[fix python bugs](https://platform.openai.com/playground/p/default-fix-python-bugs?model=code-davinci-002)**
* **[explain code](https://platform.openai.com/playground/p/default-explain-code?model=code-davinci-002)**
* **[translate code](https://platform.openai.com/playground/p/default-translate-code?model=code-davinci-002)**
* **[translate sql](https://platform.openai.com/playground/p/default-sql-translate?model=code-davinci-002)**
-* **[calculate time complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)**
-* **[text to programmatic command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)**
+* **[calculate time
+complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)**
+* **[text to programmatic
+command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)**
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..b7c624d
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,68 @@
+[tool.black]
+line-length = 88
+target-version = ['py39']
+include = '\.pyi?$'
+extend-exclude = '''
+/(
+ # directories
+ \.eggs
+ | \.git
+ | \.hg
+ | \.mypy_cache
+ | \.tox
+ | \.venv
+ | venv
+ | _build
+ | buck-out
+ | build
+ | dist
+)/
+'''
+
+[tool.isort]
+profile = "black"
+multi_line_output = 3
+line_length = 88
+known_first_party = []
+known_third_party = []
+sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
+skip = ["venv", ".venv", "__pycache__"]
+
+[tool.autopep8]
+max_line_length = 88
+aggressive = 2
+experimental = true
+
+[tool.autoflake]
+remove-all-unused-imports = true
+remove-unused-variables = true
+remove-duplicate-keys = true
+ignore-init-module-imports = true
+
+[tool.flake8]
+max-line-length = 88
+extend-ignore = ["E203", "W503"]
+exclude = [
+ ".git",
+ "__pycache__",
+ "venv",
+ ".venv",
+ "build",
+ "dist",
+ ".eggs"
+]
+
+[tool.mypy]
+python_version = "3.9"
+warn_return_any = true
+warn_unused_configs = true
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+check_untyped_defs = true
+disallow_untyped_decorators = true
+no_implicit_optional = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_no_return = true
+warn_unreachable = true
+strict_equality = true
diff --git a/scripts/auto_fix.py b/scripts/auto_fix.py
new file mode 100755
index 0000000..4319eb5
--- /dev/null
+++ b/scripts/auto_fix.py
@@ -0,0 +1,628 @@
+#!/usr/bin/env python3
+"""
+this script fixes common quality issues including:
+- code formatting (black, isort, autopep8, autoflake)
+- markdown formatting (line length, trailing whitespace)
+- markdown link validation (internal and external)
+- python code quality issues
+- import organization
+"""
+
+import os
+import re
+import subprocess
+import time
+from pathlib import Path
+
+import requests
+
+
+class AutoFixer:
+
+ def __init__(self):
+ self.venv_path = "venv"
+ self.fixes_applied = 0
+ self.errors_encountered = 0
+ self.link_report = {
+ "total_links": 0,
+ "internal_links": 0,
+ "external_links": 0,
+ "broken_internal": 0,
+ "broken_external": 0,
+ "broken_links": [],
+ }
+
+ def fix_python_code(self) -> bool:
+ print("\n🐍 fixing python code...")
+ python_files = list(Path(".").rglob("*.py"))
+ python_files = [
+ f
+ for f in python_files
+ if not any(
+ part.startswith(".") or part in ["venv", "__pycache__", ".venv"]
+ for part in f.parts
+ )
+ ]
+
+ if not python_files:
+ print("ℹ️ no python files found to fix")
+ return True
+
+ print(f"ℹ️ found {len(python_files)} Python files to fix")
+
+ print("🔧 autoflake - removing unused imports...")
+ if self._run_autoflake(python_files):
+ self.fixes_applied += 1
+ print("✅ autoflake completed")
+ else:
+ print("⚠️ autoflake had issues")
+
+ print("🔧 autopep8 - fixing code style...")
+ if self._run_autopep8(python_files):
+ self.fixes_applied += 1
+ print("✅ autopep8 completed")
+ else:
+ print("⚠️ autopep8 had issues")
+
+ print("🔧 isort - organizing imports...")
+ if self._run_isort(python_files):
+ self.fixes_applied += 1
+ print("✅ isort completed")
+ else:
+ print("⚠️ isort had issues")
+
+ print("🔧 black - applying consistent formatting...")
+ if self._run_black(python_files):
+ self.fixes_applied += 1
+ print("✅ black completed")
+ else:
+ print("⚠️ black had issues")
+ return True
+
+ def _run_autoflake(self, python_files: list[Path]) -> bool:
+ try:
+ for file_path in python_files:
+ cmd = [
+ f"{self.venv_path}/bin/autoflake",
+ "--in-place",
+ "--remove-all-unused-imports",
+ "--remove-unused-variables",
+ str(file_path),
+ ]
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ if result.returncode != 0:
+ print(f"autoflake warning for {file_path}: {result.stderr}")
+ return True
+ except Exception as e:
+ print(f"autoflake error: {e}")
+
+ def _run_autopep8(self, python_files: list[Path]) -> bool:
+ try:
+ for file_path in python_files:
+ cmd = [
+ f"{self.venv_path}/bin/autopep8",
+ "--in-place",
+ "--aggressive",
+ "--aggressive",
+ str(file_path),
+ ]
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ if result.returncode != 0:
+ print(f"autopep8 warning for {file_path}: {result.stderr}")
+ return True
+ except Exception as e:
+ print(f"autopep8 error: {e}")
+
+ def _run_isort(self, python_files: list[Path]) -> bool:
+ try:
+ for file_path in python_files:
+ cmd = [f"{self.venv_path}/bin/isort", str(file_path)]
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ if result.returncode != 0:
+ print(f"isort warning for {file_path}: {result.stderr}")
+ return True
+ except Exception as e:
+ print(f"isort error: {e}")
+
+ def _run_black(self, python_files: list[Path]) -> bool:
+ try:
+ import subprocess
+
+ for file_path in python_files:
+ cmd = [f"{self.venv_path}/bin/black", str(file_path)]
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ if result.returncode != 0:
+ print(f"black warning for {file_path}: {result.stderr}")
+ return True
+ except Exception as e:
+ print(f"black error: {e}")
+
+ def fix_markdown_files(self) -> bool:
+ print("\n📝 fixing markdown files...")
+ markdown_files = list(Path(".").rglob("*.md"))
+ markdown_files = [
+ f
+ for f in markdown_files
+ if not any(
+ part.startswith(".") or part in ["venv", "__pycache__", ".venv"]
+ for part in f.parts
+ )
+ ]
+
+ if not markdown_files:
+ print("ℹ️ no markdown files found to fix")
+ return True
+
+ print(f"ℹ️ found {len(markdown_files)} markdown files to fix")
+
+ print("🔗 checking markdown links...")
+ self.check_markdown_links(markdown_files)
+ self.fix_common_link_issues(markdown_files)
+
+ for file_path in markdown_files:
+ if self.fix_single_markdown_file(file_path):
+ self.fixes_applied += 1
+ return True
+
+ def check_markdown_links(self, markdown_files: list[Path]) -> None:
+ all_links = []
+
+ for file_path in markdown_files:
+ links = self.extract_links_from_markdown(file_path)
+ for link in links:
+ link["source_file"] = file_path
+ all_links.append(link)
+
+ if not all_links:
+ print("ℹ️ no links found in markdown files")
+ return
+
+ self.link_report["total_links"] = len(all_links)
+ print(f"ℹ️ found {len(all_links)} links to check")
+
+ internal_links = [link for link in all_links if not link["is_external"]]
+ self.link_report["internal_links"] = len(internal_links)
+ if internal_links:
+ print(f"🔍 checking {len(internal_links)} internal links...")
+ self.check_internal_links(internal_links)
+
+ external_links = [link for link in all_links if link["is_external"]]
+ self.link_report["external_links"] = len(external_links)
+ if external_links:
+ print(f"🌐 checking {len(external_links)} external links...")
+ self.check_external_links(external_links)
+
+ def extract_links_from_markdown(self, file_path: Path) -> list[dict]:
+ links = []
+ try:
+ with open(file_path, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ link_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
+ matches = re.findall(link_pattern, content)
+
+ for text, url in matches:
+ is_external = url.startswith(("http://", "https://", "mailto:"))
+
+ links.append(
+ {
+ "text": text.strip(),
+ "url": url.strip(),
+ "is_external": is_external,
+ "line_number": self.get_line_number_for_link(
+ content, text, url
+ ),
+ }
+ )
+
+ except Exception as e:
+ print(f"❌ error reading {file_path}: {e}")
+
+ return links
+
+ def get_line_number_for_link(self, content: str, text: str, url: str) -> int:
+ lines = content.split("\n")
+ for i, line in enumerate(lines, 1):
+ if f"[{text}]({url})" in line:
+ return i
+ return 0
+
+ def check_internal_links(self, internal_links: list[dict]) -> None:
+ broken_links = []
+
+ for link in internal_links:
+ source_file = link["source_file"]
+ url = link["url"]
+ text = link["text"]
+ line_num = link["line_number"]
+
+ if url.startswith("../"):
+ target_path = source_file.parent.parent / url[3:]
+ elif url.startswith("./"):
+ target_path = source_file.parent / url[2:]
+ else:
+ target_path = source_file.parent / url
+
+ if not target_path.exists():
+ broken_link_info = {
+ "source_file": source_file,
+ "text": text,
+ "url": url,
+ "line_number": line_num,
+ "target_path": target_path,
+ "issue": "File not found",
+ "type": "internal",
+ }
+ broken_links.append(broken_link_info)
+ self.link_report["broken_links"].append(broken_link_info)
+
+ self.link_report["broken_internal"] = len(broken_links)
+ if broken_links:
+ print(f"❌ found {len(broken_links)} broken internal links:")
+ for link in broken_links:
+ print(
+ f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
+ )
+ else:
+ print("✅ all internal links are valid")
+
+ def check_external_links(self, external_links: list[dict]) -> None:
+ broken_links = []
+ checked_count = 0
+
+ for link in external_links:
+ url = link["url"]
+ text = link["text"]
+ source_file = link["source_file"]
+ line_num = link["line_number"]
+
+ try:
+ time.sleep(0.1)
+
+ response = requests.head(url, timeout=10, allow_redirects=True)
+ checked_count += 1
+
+ if response.status_code >= 400:
+ if (
+ response.status_code == 429
+ or response.status_code == 403
+ or response.status_code == 443
+ ):
+ if checked_count % 10 == 0:
+ print(
+ f" checked {checked_count}/{len(external_links)} external links..."
+ )
+ continue
+
+ broken_link_info = {
+ "source_file": source_file,
+ "text": text,
+ "url": url,
+ "line_number": line_num,
+ "status_code": response.status_code,
+ "issue": f"HTTP {response.status_code}",
+ "type": "external",
+ }
+ broken_links.append(broken_link_info)
+ self.link_report["broken_links"].append(broken_link_info)
+
+ if checked_count % 10 == 0:
+ print(
+ f" checked {checked_count}/{len(external_links)} external links..."
+ )
+
+ except requests.exceptions.RequestException as e:
+ broken_link_info = {
+ "source_file": source_file,
+ "text": text,
+ "url": url,
+ "line_number": line_num,
+ "issue": f"Connection error: {str(e)}",
+ "type": "external",
+ }
+ broken_links.append(broken_link_info)
+ self.link_report["broken_links"].append(broken_link_info)
+ except Exception as e:
+ broken_link_info = {
+ "source_file": source_file,
+ "text": text,
+ "url": url,
+ "line_number": line_num,
+ "issue": f"Error: {str(e)}",
+ "type": "external",
+ }
+ broken_links.append(broken_link_info)
+ self.link_report["broken_links"].append(broken_link_info)
+
+ self.link_report["broken_external"] = len(broken_links)
+ if broken_links:
+ print(f"❌ found {len(broken_links)} broken external links:")
+ for link in broken_links:
+ print(
+ f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
+ )
+ else:
+ print("✅ all external links are accessible")
+
+ print(f"ℹ️ checked {checked_count} external links")
+
+ def fix_common_link_issues(self, markdown_files: list[Path]) -> None:
+ print("🔧 fixing common link issues...")
+ fixed_count = 0
+
+ for file_path in markdown_files:
+ if self.fix_links_in_file(file_path):
+ fixed_count += 1
+
+ if fixed_count > 0:
+ print(f"✅ fixed links in {fixed_count} files")
+ self.fixes_applied += 1
+ else:
+ print("ℹ️ no link issues found to fix")
+
+ def fix_links_in_file(self, file_path: Path) -> bool:
+ try:
+ with open(file_path, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ changes_made = False
+ eth_pattern = r"([a-zA-Z0-9]+\.eth)"
+ if re.search(eth_pattern, content):
+ content = re.sub(eth_pattern, r"\1".replace(".eth", ""), content)
+ changes_made = True
+
+ double_space_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
+
+ def fix_spaces(match):
+ text = match.group(1).strip()
+ url = match.group(2).strip()
+ if text != match.group(1) or url != match.group(2):
+ return f"[{text}]({url})"
+ return match.group(0)
+
+ new_content = re.sub(double_space_pattern, fix_spaces, content)
+ if new_content != content:
+ content = new_content
+ changes_made = True
+
+ if changes_made:
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.write(content)
+ print(f" 🔧 fixed links in {file_path}")
+ return True
+
+ except Exception as e:
+ print(f"❌ error fixing links in {file_path}: {e}")
+
+ return False
+
+ def fix_single_markdown_file(self, file_path: Path) -> bool:
+ try:
+ with open(file_path, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ lines.copy()
+ fixed_lines = []
+ changes_made = False
+ in_code_block = False
+
+ for _, line in enumerate(lines):
+ line = line.rstrip()
+
+ if not line.strip():
+ fixed_lines.append("\n")
+ continue
+
+ if line.startswith("```"):
+ in_code_block = not in_code_block
+ fixed_lines.append(line + "\n")
+ continue
+
+ if in_code_block:
+ fixed_lines.append(line + "\n")
+ continue
+
+ if line.strip().startswith(("- ", "* ", "+ ", "1. ")):
+ if len(line) > 120:
+ broken_line = self._break_list_item(line)
+ if broken_line != line:
+ changes_made = True
+ print(f" breaking long list item in {file_path}")
+ fixed_lines.append(broken_line + "\n")
+ else:
+ fixed_lines.append(line + "\n")
+ continue
+
+ if len(line) > 120:
+ broken_lines = self.break_long_line(line)
+ if broken_lines != line:
+ changes_made = True
+ print(
+ f" breaking long line in {file_path}: {len(line)} chars -> {len(broken_lines.split(chr(10))[0])} chars"
+ )
+
+ for broken_line in broken_lines.split("\n"):
+ if broken_line.strip():
+ fixed_lines.append(broken_line + "\n")
+ else:
+ fixed_lines.append("\n")
+ else:
+ fixed_lines.append(line + "\n")
+
+ if changes_made:
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.writelines(fixed_lines)
+ print(f" ✅ fixed {file_path}")
+ return True
+ else:
+ print(f" ℹ️ no changes needed in {file_path}")
+
+ except Exception as e:
+ print(f"❌ error fixing {file_path}: {e}")
+ self.errors_encountered += 1
+
+ def break_long_line(self, line: str) -> str:
+ if len(line) <= 120:
+ return line
+
+ if ". " in line:
+ parts = line.split(". ")
+ if len(parts[0]) <= 120:
+ remaining = ". ".join(parts[1:])
+ if len(remaining) <= 120:
+ return parts[0] + ". " + remaining
+ else:
+ broken_remaining = self._break_at_words(remaining)
+ return parts[0] + ".\n" + broken_remaining
+ return self._break_at_words(line)
+
+ def _break_at_words(self, line: str) -> str:
+ words = line.split()
+ result = []
+ current_line = ""
+
+ for word in words:
+ if current_line and len(current_line + " " + word) > 120:
+ if current_line:
+ result.append(current_line)
+ current_line = word
+ else:
+ if current_line:
+ current_line += " " + word
+ else:
+ current_line = word
+
+ if current_line:
+ result.append(current_line)
+
+ return "\n".join(result)
+
+ def _break_list_item(self, line: str) -> str:
+ marker_end = 0
+ for i, char in enumerate(line):
+ if char in "-*+" or (char.isdigit() and line[i + 1 : i + 3] == ". "):
+ marker_end = line.find(" ", i)
+ if marker_end == -1:
+ marker_end = len(line)
+ break
+
+ if marker_end == 0:
+ return self.break_long_line(line)
+
+ marker = line[: marker_end + 1]
+ content = line[marker_end + 1 :]
+
+ if len(content) <= 120 - len(marker):
+ return line
+
+ broken_content = self._break_at_words(content)
+ if "\n" in broken_content:
+ indent = " " * len(marker)
+ lines = broken_content.split("\n")
+ result = [marker + lines[0]]
+ for continuation_line in lines[1:]:
+ result.append(indent + continuation_line)
+ return "\n".join(result)
+
+ return line
+
+ def fix_trailing_whitespace(self) -> bool:
+ print("\n🧹 fixing trailing whitespace...")
+
+ text_extensions = {".py", ".md", ".txt", ".rst", ".yml", ".yaml", ".json"}
+ files_fixed = 0
+
+ for root, dirs, files in os.walk("."):
+ dirs[:] = [
+ d
+ for d in dirs
+ if not d.startswith(".")
+ and d not in ["venv", "__pycache__", ".venv", "node_modules"]
+ ]
+
+ for file in files:
+ if any(file.endswith(ext) for ext in text_extensions):
+ file_path = os.path.join(root, file)
+ if self.fix_file_trailing_whitespace(file_path):
+ files_fixed += 1
+
+ if files_fixed > 0:
+ self.fixes_applied += 1
+ print(f"ℹ️ fixed trailing whitespace in {files_fixed} files")
+ return True
+
+ def fix_file_trailing_whitespace(self, file_path: str) -> bool:
+ try:
+ with open(file_path, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ original_lines = lines.copy()
+ fixed_lines = []
+
+ for line in lines:
+ fixed_line = line.rstrip() + "\n"
+ fixed_lines.append(fixed_line)
+
+ if fixed_lines != original_lines:
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.writelines(fixed_lines)
+ print(f"ℹ️ fixed trailing whitespace in {file_path}")
+ return True
+
+ except Exception as e:
+ print(f"warning: could not fix {file_path}: {e}")
+
+ def run_all_fixes(self) -> bool:
+ print("🚀 starting auto-fix process...")
+ success = True
+ success &= self.fix_trailing_whitespace()
+ success &= self.fix_python_code()
+ success &= self.fix_markdown_files()
+ return success
+
+ def print_summary(self):
+ print("\n" + "=" * 50)
+ print("🎯 AUTO-FIX SUMMARY")
+ print("=" * 50)
+ print(f"✅ fixes applied: {self.fixes_applied}")
+ print(f"❌ errors encountered: {self.errors_encountered}")
+
+ if self.link_report["total_links"] > 0:
+ print("\n🔗 LINK CHECK SUMMARY")
+ print("-" * 30)
+ print(f"📊 total links found: {self.link_report['total_links']}")
+ print(f"🔍 internal links: {self.link_report['internal_links']}")
+ print(f"🌐 external links: {self.link_report['external_links']}")
+ print(f"❌ broken internal: {self.link_report['broken_internal']}")
+ print(f"❌ broken external: {self.link_report['broken_external']}")
+
+ if self.link_report["broken_links"]:
+ print(f"\n⚠️ broken links found:")
+ for link in self.link_report["broken_links"]:
+ print(
+ f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
+ )
+
+ if (
+ self.errors_encountered == 0
+ and self.link_report["broken_internal"] == 0
+ and self.link_report["broken_external"] == 0
+ ):
+ print("\n🎉 all fixes completed successfully and all links are working!")
+ elif self.errors_encountered == 0:
+ print(f"\n✅ all fixes completed successfully!")
+ print(
+ f"⚠️ but {self.link_report['broken_internal'] + self.link_report['broken_external']} broken links were found"
+ )
+ else:
+ print(f"\n⚠️ {self.errors_encountered} errors occurred during fixing")
+
+
+def main():
+ fixer = AutoFixer()
+ fixer.run_all_fixes()
+ fixer.print_summary()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/requirements.txt b/scripts/requirements.txt
new file mode 100644
index 0000000..11f2f49
--- /dev/null
+++ b/scripts/requirements.txt
@@ -0,0 +1,10 @@
+textstat>=0.7.3
+requests>=2.28.0
+beautifulsoup4>=4.11.0
+markdown>=3.4.0
+black>=23.0.0
+isort>=5.12.0
+flake8>=6.0.0
+mypy>=1.0.0
+autopep8>=2.0.0
+autoflake>=2.0.0