chores: refactor for the new ai research, add linter, gh action, etc (#27)

This commit is contained in:
Marina von Steinkirch, PhD 2025-08-13 21:49:46 +08:00 committed by von-steinkirch
parent fb4ab80dc3
commit d5467e559f
40 changed files with 5177 additions and 2476 deletions

0
.github/.keep vendored Normal file
View file

68
.github/workflows/auto-fix.yml vendored Normal file
View file

@ -0,0 +1,68 @@
name: 👾 auto-fix code quality
on:
pull_request:
branches: [ main, master ]
push:
branches: [ main, master ]
workflow_dispatch: # allow manual triggering
jobs:
auto-fix:
runs-on: ubuntu-latest
steps:
- name: checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # full history for better diff detection
- name: set up python
uses: actions/setup-python@v4
with:
python-version: '3.9'
cache: 'pip'
- name: install dependencies
run: |
python -m pip install --upgrade pip
pip install -r scripts/requirements.txt
- name: install code quality tools
run: |
pip install black isort autopep8 autoflake
- name: run auto-fix script
run: |
python scripts/auto_fix.py
- name: check for changes
id: check_changes
run: |
if [ -n "$(git status --porcelain)" ]; then
echo "changes=true" >> $GITHUB_OUTPUT
echo "files were modified by auto-fix script"
git status --porcelain
else
echo "changes=false" >> $GITHUB_OUTPUT
echo "no files were modified"
fi
- name: commit and push changes (if any)
if: steps.check_changes.outputs.changes == 'true'
run: |
git config --local user.email "action@github.com"
git config --local user.name "github action"
git add -a
git commit -m "🔧 auto-fix code quality issues
- applied black formatting
- organized imports with isort
- fixed code style with autopep8
- removed unused imports with autoflake
- fixed markdown formatting
- validated and fixed links
- removed trailing whitespace
auto-generated by github actions"
git push

204
.gitignore vendored Normal file
View file

@ -0,0 +1,204 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Django stuff (keeping in case of web components)
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff (keeping in case of web components)
instance/
.webassets-cache
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# poetry
poetry.lock
# pdm
pdm.lock
.pdm.toml
# PEP 582
__pypackages__/
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
.idea/
*.iml
*.ipr
*.iws
# VS Code
.vscode/
# macOS
.DS_Store
.AppleDouble
.LSOverride
Icon
._*
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# Windows
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
*.stackdump
[Dd]esktop.ini
$RECYCLE.BIN/
*.cab
*.msi
*.msix
*.msm
*.msp
*.lnk
# Linux
*~
.fuse_hidden*
.directory
.Trash-*
.nfs*
# Machine Learning specific
*.pkl
*.pickle
*.joblib
*.h5
*.hdf5
*.model
*.weights
*.ckpt
*.pth
*.pt
*.onnx
*.tflite
*.pb
# Data files
*.csv
*.json
*.parquet
*.feather
*.hdf
*.xlsx
*.xls
# Large model files
models/
checkpoints/
runs/
logs/
wandb/
# Environment variables
.env.local
.env.development
.env.test
.env.production
# IDE specific
*.swp
*.swo

View file

@ -1,12 +1,15 @@
## quantum ai: training energy-based-models using openai ## quantum ai: training energy-based-models using openAI
<br> <br>
#### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in energy-based-models](https://arxiv.org/pdf/1903.08689.pdf) #### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in
energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
<br> <br>
---
### installing ### installing
<br> <br>
@ -19,7 +22,8 @@ brew install pkg-config
<br> <br>
* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this problem (`PMIX ERROR: ERROR`) that can be fixed with: * there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this
problem (`PMIX ERROR: ERROR`) that can be fixed with:
<br> <br>
@ -40,7 +44,8 @@ pip install -r requirements.txt
``` ```
<br> <br>
* note that this is an adapted requirement file since the **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct * note that this is an adapted requirement file since the **[openai's
original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
* finally, download and install **[mujoco](https://www.roboti.us/index.html)** * finally, download and install **[mujoco](https://www.roboti.us/index.html)**
* you will also need to register for a license, which asks for a machine ID * you will also need to register for a license, which asks for a machine ID
* the documentation on the website is incomplete, so just download the suggested script and run: * the documentation on the website is incomplete, so just download the suggested script and run:
@ -64,7 +69,8 @@ mv getid_osx getid_osx.dms
<br> <br>
* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`: * download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder
`cachedir`:
<br> <br>
@ -78,7 +84,8 @@ mkdir cachedir
<br> <br>
* openai's original code contains **[hardcoded constants that only work on Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)** * openai's original code contains **[hardcoded constants that only work on
Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
* i changed this to a constant (`ROOT_DIR = "./results"`) in the top of `data.py` * i changed this to a constant (`ROOT_DIR = "./results"`) in the top of `data.py`
<br> <br>
@ -87,7 +94,8 @@ mkdir cachedir
<br> <br>
* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased substantially by using multiple different workers by running each command: * all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased
substantially by using multiple different workers by running each command:
<br> <br>
@ -102,7 +110,8 @@ mpiexec -n <worker_num> <command>
<br> <br>
``` ```
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
--zero_kl --replay_batch --large_model
``` ```
* this should generate the following output: * this should generate the following output:
@ -112,7 +121,8 @@ python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_si
```bash ```bash
Instructions for updating: Instructions for updating:
Use tf.gfile.GFile. Use tf.gfile.GFile.
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization(). 2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is
deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
64 batch size 64 batch size
Local rank: 0 1 Local rank: 0 1
Loading data... Loading data...
@ -121,11 +131,15 @@ Files already downloaded and verified
Files already downloaded and verified Files already downloaded and verified
Files already downloaded and verified Files already downloaded and verified
Done loading... Done loading...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. WARNING:tensorflow:From
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263:
colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating: Instructions for updating:
Colocations handled automatically by placer. Colocations handled automatically by placer.
Building graph... Building graph...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. WARNING:tensorflow:From
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from
tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating: Instructions for updating:
Use tf.cast instead. Use tf.cast instead.
Finished processing loop construction ... Finished processing loop construction ...
@ -136,16 +150,36 @@ Model has a total of 7567880 parameters
Initializing variables... Initializing variables...
Start broadcast Start broadcast
End broadcast End broadcast
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996, l Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff:
oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818, context_0/res_optim_res_c1/ 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent:
cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10, context_0/res_optim_res_c2/gb:0: 6.27235652306268 2.272536277770996, l
3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0 oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first:
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818,
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.0194275379180908 context_0/res_optim_res_c1/
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.75612463 cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10,
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0: 0.34806951880455017, context_0/res_4_res_c2/g: context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10,
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5 context_0/res_optim_res_c2/gb:0: 6.27235652306268
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/ 3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11,
context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437,
context_0/res_1_res_c2/g:0
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0:
1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0:
6.868505764145993e-10, context_0/res_2
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0:
8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0:
0.0194275379180908
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10,
context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11,
context_0/res_3_res_c2/gb:0: 7.75612463
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0:
4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0:
0.34806951880455017, context_0/res_4_res_c2/g:
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0:
5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0:
3.807749116013781e-10, context_0/res_5
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0:
7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039,
context_0/fc5/
b:0: 0.4506262540817261, b:0: 0.4506262540817261,
................................................................................................................................ ................................................................................................................................
@ -159,7 +193,8 @@ Inception score of 1.2397289276123047 with std of 0.0
<br> <br>
``` ```
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
--zero_kl --replay_batch --cclass
``` ```
<br> <br>
@ -169,7 +204,8 @@ python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size
<br> <br>
``` ```
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path> python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01
--replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
``` ```
<br> <br>
@ -179,7 +215,8 @@ python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=3
<br> <br>
``` ```
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path> python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass
--zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
``` ```
<br> <br>
@ -197,7 +234,8 @@ python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
``` ```
* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by different settings of task flag in the file * the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by
different settings of task flag in the file
* for example, to visualize cross class mappings in cifar-10, you can run: * for example, to visualize cross class mappings in cifar-10, you can run:
<br> <br>
@ -217,7 +255,8 @@ python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resu
<br> <br>
``` ```
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200
--large_model --svhnmix --cclass=False
``` ```
<br> <br>
@ -227,7 +266,8 @@ python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_
<br> <br>
``` ```
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd steps> --num_steps=10 --lival=<li bound value> --wider_model python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd
steps> --num_steps=10 --lival=<li bound value> --wider_model
``` ```
<br> <br>
@ -236,12 +276,14 @@ python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=
<br> <br>
* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in `cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below: * to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in
`cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
<br> <br>
``` ```
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch -cclass python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act
--cond_pos --replay_batch -cclass
``` ```
<br> <br>
@ -249,5 +291,7 @@ python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps
* once models are trained, they can be sampled from jointly by running: * once models are trained, they can be sampled from jointly by running:
``` ```
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos> --exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot> --resume_pos=<resume_pos> python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos>
--exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot>
--resume_pos=<resume_pos>
``` ```

View file

@ -1,40 +1,65 @@
import tensorflow as tf
import math import math
import os.path as osp
import numpy as np
import tensorflow as tf
from data import Cifar10, DSprites, Mnist
from hmc import hmc from hmc import hmc
from models import DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Wider
from tensorflow.python.platform import flags from tensorflow.python.platform import flags
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
from data import Cifar10, Mnist, DSprites
from scipy.misc import logsumexp
from scipy.misc import imsave
from utils import optimistic_restore
import os.path as osp
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from utils import optimistic_restore
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single') flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss') flags.DEFINE_string(
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') "dataset", "cifar10", "cifar10 or mnist or dsprites or 2d or toy Gauss"
flags.DEFINE_string('exp', 'default', 'name of experiments') )
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel') flags.DEFINE_string(
flags.DEFINE_integer('batch_size', 16, 'Size of inputs') "logdir", "cachedir", "location where log of experiments will be stored"
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from') )
flags.DEFINE_string("exp", "default", "name of experiments")
flags.DEFINE_integer(
"data_workers", 5, "Number of different data workers to load data in parallel"
)
flags.DEFINE_integer("batch_size", 16, "Size of inputs")
flags.DEFINE_string("resume_iter", "-1", "iteration to resume training from")
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions') flags.DEFINE_bool(
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.') "max_pool",
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais') False,
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian') "Whether or not to use max pooling rather than strided convolutions",
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box') )
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model') flags.DEFINE_integer(
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not') "num_filters",
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') 64,
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') "number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') )
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not') flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not') flags.DEFINE_integer("gauss_dim", 500, "dimensions for modeling Gaussian")
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate') flags.DEFINE_integer(
flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate') "rescale", 1, "factor to rescale input outside of normal (0, 1) box"
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps') )
flags.DEFINE_float(
"temperature", 1, "temperature at which to compute likelihood of model"
)
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_bool(
"cclass",
False,
"Whether to evaluate the log likelihood of conditional model or not",
)
flags.DEFINE_bool(
"single",
False,
"Whether to evaluate the log likelihood of conditional model or not",
)
flags.DEFINE_bool("large_model", False, "Use large model to evaluate")
flags.DEFINE_bool("wider_model", False, "Use large model to evaluate")
flags.DEFINE_float("alr", 0.0045, "Learning rate to use for HMC steps")
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@ -45,11 +70,12 @@ label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
def unscale_im(im): def unscale_im(im):
return (255 * np.clip(im, 0, 1)).astype(np.uint8) return (255 * np.clip(im, 0, 1)).astype(np.uint8)
def gauss_prob_log(x, prec=1.0): def gauss_prob_log(x, prec=1.0):
nh = float(np.prod([s.value for s in x.get_shape()[1:]])) nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec)) norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2.0 * prec
return norm_constant_log + prob_density_log return norm_constant_log + prob_density_log
@ -73,23 +99,36 @@ def model_prob_log(x, e_func, weights, temp):
def bridge_prob_neg_log(alpha, x, e_func, weights, temp): def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
if FLAGS.dataset == "gauss": if FLAGS.dataset == "gauss":
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature) norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(
x, prec=FLAGS.temperature
)
else: else:
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp) norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(
# Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely x, e_func, weights, temp
)
# Add an additional log likelihood penalty so that points outside of (0,
# 1) box are *highly* unlikely
if FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss': oob_prob = tf.reduce_sum(
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1]) tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1]
elif FLAGS.dataset == 'mnist': )
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2]) elif FLAGS.dataset == "mnist":
oob_prob = tf.reduce_sum(
tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2]
)
else: else:
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3]) oob_prob = tf.reduce_sum(
tf.square(100 * (x - tf.clip_by_value(x, 0.0, FLAGS.rescale))),
axis=[1, 2, 3],
)
return -norm_prob + oob_prob return -norm_prob + oob_prob
def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10): def ancestral_sample(
e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10
):
if FLAGS.dataset == "2d": if FLAGS.dataset == "2d":
x = tf.placeholder(tf.float32, shape=(None, 2)) x = tf.placeholder(tf.float32, shape=(None, 2))
elif FLAGS.dataset == "gauss": elif FLAGS.dataset == "gauss":
@ -130,41 +169,46 @@ def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_
def main(): def main():
# Initialize dataset # Initialize dataset
if FLAGS.dataset == 'cifar10': if FLAGS.dataset == "cifar10":
dataset = Cifar10(train=False, rescale=FLAGS.rescale) dataset = Cifar10(train=False, rescale=FLAGS.rescale)
channel_num = 3 channel_num = 3
dim_input = 32 * 32 * 3 32 * 32 * 3
elif FLAGS.dataset == 'imagenet': elif FLAGS.dataset == "imagenet":
dataset = ImagenetClass() dataset = ImagenetClass()
channel_num = 3 channel_num = 3
dim_input = 64 * 64 * 3 64 * 64 * 3
elif FLAGS.dataset == 'mnist': elif FLAGS.dataset == "mnist":
dataset = Mnist(train=False, rescale=FLAGS.rescale) dataset = Mnist(train=False, rescale=FLAGS.rescale)
channel_num = 1 channel_num = 1
dim_input = 28 * 28 * 1 28 * 28 * 1
elif FLAGS.dataset == 'dsprites': elif FLAGS.dataset == "dsprites":
dataset = DSprites() dataset = DSprites()
channel_num = 1 channel_num = 1
dim_input = 64 * 64 * 1 64 * 64 * 1
elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss': elif FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
dataset = Box2D() dataset = Box2D()
dim_output = 1 data_loader = DataLoader(
data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True) dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.data_workers,
drop_last=False,
shuffle=True,
)
if FLAGS.dataset == 'mnist': if FLAGS.dataset == "mnist":
model = MnistNet(num_channels=channel_num) model = MnistNet(num_channels=channel_num)
elif FLAGS.dataset == 'cifar10': elif FLAGS.dataset == "cifar10":
if FLAGS.large_model: if FLAGS.large_model:
model = ResNet32Large(num_filters=128) model = ResNet32Large(num_filters=128)
elif FLAGS.wider_model: elif FLAGS.wider_model:
model = ResNet32Wider(num_filters=192) model = ResNet32Wider(num_filters=192)
else: else:
model = ResNet32(num_channels=channel_num, num_filters=128) model = ResNet32(num_channels=channel_num, num_filters=128)
elif FLAGS.dataset == 'dsprites': elif FLAGS.dataset == "dsprites":
model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters) model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
weights = model.construct_weights('context_{}'.format(0)) weights = model.construct_weights("context_{}".format(0))
config = tf.ConfigProto() config = tf.ConfigProto()
sess = tf.Session(config=config) sess = tf.Session(config=config)
@ -173,8 +217,8 @@ def main():
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
logdir = osp.join(FLAGS.logdir, FLAGS.exp) logdir = osp.join(FLAGS.logdir, FLAGS.exp)
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
resume_itr = FLAGS.resume_iter FLAGS.resume_iter
if FLAGS.resume_iter != "-1": if FLAGS.resume_iter != "-1":
optimistic_restore(sess, model_file) optimistic_restore(sess, model_file)
@ -182,14 +226,17 @@ def main():
print("WARNING, YOU ARE NOT LOADING A SAVE FILE") print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
# saver.restore(sess, model_file) # saver.restore(sess, model_file)
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature) chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
model, weights, FLAGS.batch_size, temp=FLAGS.temperature
)
print("Finished constructing ancestral sample ...................") print("Finished constructing ancestral sample ...................")
if FLAGS.dataset != "gauss": if FLAGS.dataset != "gauss":
comb_weights_cum = []
batch_size = tf.shape(x_init)[0] batch_size = tf.shape(x_init)[0]
label_tiled = tf.tile(label_default, (batch_size, 1)) label_tiled = tf.tile(label_default, (batch_size, 1))
e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled) e_compute = -FLAGS.temperature * model.forward(
x_init, weights, label=label_tiled
)
e_pos_list = [] e_pos_list = []
for data_corrupt, data, label_gt in tqdm(data_loader): for data_corrupt, data, label_gt in tqdm(data_loader):
@ -205,44 +252,75 @@ def main():
alr = 0.0085 alr = 0.0085
elif FLAGS.dataset == "mnist": elif FLAGS.dataset == "mnist":
alr = 0.0065 alr = 0.0065
#90 alr = 0.0035 # 90 alr = 0.0035
else: else:
# alr = 0.0125 # alr = 0.0125
if FLAGS.rescale == 8: if FLAGS.rescale == 8:
alr = 0.0085 alr = 0.0085
else: else:
alr = 0.0045 alr = 0.0045
# #
for i in range(1): for i in range(1):
tot_weight = 0 tot_weight = 0
for j in tqdm(range(1, FLAGS.pdist+1)): for j in tqdm(range(1, FLAGS.pdist + 1)):
if j == 1: if j == 1:
if FLAGS.dataset == "cifar10": if FLAGS.dataset == "cifar10":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)) x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)
)
elif FLAGS.dataset == "gauss": elif FLAGS.dataset == "gauss":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)) x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)
)
elif FLAGS.dataset == "mnist": elif FLAGS.dataset == "mnist":
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)) x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)
)
else: else:
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2)) x_curr = np.random.uniform(
0, FLAGS.rescale, size=(FLAGS.batch_size, 2)
)
alpha_prev = (j-1) / FLAGS.pdist alpha_prev = (j - 1) / FLAGS.pdist
alpha_new = j / FLAGS.pdist alpha_new = j / FLAGS.pdist
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))}) cweight, x_curr = sess.run(
[chain_weights, x],
{
a_prev: alpha_prev,
a_new: alpha_new,
x_init: x_curr,
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
},
)
tot_weight = tot_weight + cweight tot_weight = tot_weight + cweight
print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight)) print(
"Total values of lower value based off forward sampling",
np.mean(tot_weight),
np.std(tot_weight),
)
tot_weight = 0 tot_weight = 0
for j in tqdm(range(FLAGS.pdist, 0, -1)): for j in tqdm(range(FLAGS.pdist, 0, -1)):
alpha_new = (j-1) / FLAGS.pdist alpha_new = (j - 1) / FLAGS.pdist
alpha_prev = j / FLAGS.pdist alpha_prev = j / FLAGS.pdist
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))}) cweight, x_curr = sess.run(
[chain_weights, x],
{
a_prev: alpha_prev,
a_new: alpha_new,
x_init: x_curr,
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
},
)
tot_weight = tot_weight - cweight tot_weight = tot_weight - cweight
print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight)) print(
"Total values of upper value based off backward sampling",
np.mean(tot_weight),
np.std(tot_weight),
)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -14,223 +14,244 @@
# ============================================================================== # ==============================================================================
"""Adam for TensorFlow.""" """Adam for TensorFlow."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import (
from tensorflow.python.ops import math_ops control_flow_ops,
from tensorflow.python.ops import resource_variable_ops math_ops,
from tensorflow.python.ops import state_ops resource_variable_ops,
from tensorflow.python.training import optimizer state_ops,
from tensorflow.python.training import training_ops )
from tensorflow.python.training import optimizer, training_ops
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
import tensorflow as tf
@tf_export("train.AdamOptimizer") @tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer): class AdamOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adam algorithm. """Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)). ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam"):
"""Construct a new Adam optimizer.
Initialization:
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
$$t := 0 \text{(Initialize timestep)}$$
The update rule for `variable` with gradient `g` uses an optimization
described at the end of section2 of the paper:
$$t := t + 1$$
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
The default value of 1e-8 for epsilon might not be a good default in
general. For example, when training an Inception network on ImageNet a
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
formulation just before Section 2.1 of the Kingma and Ba paper rather than
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
hat" in the paper.
The sparse implementation of this algorithm (used when the gradient is an
IndexedSlices object, typically because of `tf.gather` or an embedding
lookup in the forward pass) does apply momentum to variable slices even if
they were not used in the forward pass (meaning they have a gradient equal
to zero). Momentum decay (beta1) is also applied to the entire momentum
accumulator. This means that the sparse behavior is equivalent to the dense
behavior (in contrast to some momentum implementations which ignore momentum
unless a variable slice was actually used).
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
beta1: A float value or a constant float tensor.
The exponential decay rate for the 1st moment estimates.
beta2: A float value or a constant float tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just before
Section 2.1), not the epsilon in Algorithm 1 of the paper.
use_locking: If True use locks for update operations.
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
@compatibility(eager)
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
`epsilon` can each be a callable that takes no arguments and returns the
actual value to use. This can be useful for changing these values across
different invocations of optimizer functions.
@end_compatibility
""" """
super(AdamOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
self._beta1 = beta1
self._beta2 = beta2
self._epsilon = epsilon
# Tensor versions of the constructor arguments, created in _prepare(). def __init__(
self._lr_t = None self,
self._beta1_t = None learning_rate=0.001,
self._beta2_t = None beta1=0.9,
self._epsilon_t = None beta2=0.999,
epsilon=1e-8,
use_locking=False,
name="Adam",
):
"""Construct a new Adam optimizer.
# Created in SparseApply if needed. Initialization:
self._updated_lr = None
def _get_beta_accumulators(self): $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
with ops.init_scope(): $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
if context.executing_eagerly(): $$t := 0 \text{(Initialize timestep)}$$
graph = None
else:
graph = ops.get_default_graph()
return (self._get_non_slot_variable("beta1_power", graph=graph),
self._get_non_slot_variable("beta2_power", graph=graph))
def _create_slots(self, var_list): The update rule for `variable` with gradient `g` uses an optimization
# Create the beta1 and beta2 accumulators on the same device as the first described at the end of section2 of the paper:
# variable. Sort the var_list to make sure this device is consistent across
# workers (these need to go on the same PS, otherwise some updates are
# silently ignored).
first_var = min(var_list, key=lambda x: x.name)
self._create_non_slot_variable(initial_value=self._beta1,
name="beta1_power",
colocate_with=first_var)
self._create_non_slot_variable(initial_value=self._beta2,
name="beta2_power",
colocate_with=first_var)
# Create slots for the first and second moments. $$t := t + 1$$
for v in var_list: $$lr_t := \text{learning\\_rate} * \\sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
self._zeros_slot(v, "m", self._name)
self._zeros_slot(v, "v", self._name)
def _prepare(self): $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
lr = self._call_if_callable(self._lr) $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
beta1 = self._call_if_callable(self._beta1) $$variable := variable - lr_t * m_t / (\\sqrt{v_t} + \\epsilon)$$
beta2 = self._call_if_callable(self._beta2)
epsilon = self._call_if_callable(self._epsilon)
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") The default value of 1e-8 for epsilon might not be a good default in
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") general. For example, when training an Inception network on ImageNet a
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") formulation just before Section 2.1 of the Kingma and Ba paper rather than
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
hat" in the paper.
def _apply_dense(self, grad, var): The sparse implementation of this algorithm (used when the gradient is an
m = self.get_slot(var, "m") IndexedSlices object, typically because of `tf.gather` or an embedding
v = self.get_slot(var, "v") lookup in the forward pass) does apply momentum to variable slices even if
beta1_power, beta2_power = self._get_beta_accumulators() they were not used in the forward pass (meaning they have a gradient equal
to zero). Momentum decay (beta1) is also applied to the entire momentum
accumulator. This means that the sparse behavior is equivalent to the dense
behavior (in contrast to some momentum implementations which ignore momentum
unless a variable slice was actually used).
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1 Args:
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds) learning_rate: A Tensor or a floating point value. The learning rate.
# Clip gradients by 3 std beta1: A float value or a constant float tensor.
return training_ops.apply_adam( The exponential decay rate for the 1st moment estimates.
var, m, v, beta2: A float value or a constant float tensor.
math_ops.cast(beta1_power, var.dtype.base_dtype), The exponential decay rate for the 2nd moment estimates.
math_ops.cast(beta2_power, var.dtype.base_dtype), epsilon: A small constant for numerical stability. This epsilon is
math_ops.cast(self._lr_t, var.dtype.base_dtype), "epsilon hat" in the Kingma and Ba paper (in the formula just before
math_ops.cast(self._beta1_t, var.dtype.base_dtype), Section 2.1), not the epsilon in Algorithm 1 of the paper.
math_ops.cast(self._beta2_t, var.dtype.base_dtype), use_locking: If True use locks for update operations.
math_ops.cast(self._epsilon_t, var.dtype.base_dtype), name: Optional name for the operations created when applying gradients.
grad, use_locking=self._use_locking).op Defaults to "Adam".
def _resource_apply_dense(self, grad, var): @compatibility(eager)
m = self.get_slot(var, "m") When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
v = self.get_slot(var, "v") `epsilon` can each be a callable that takes no arguments and returns the
beta1_power, beta2_power = self._get_beta_accumulators() actual value to use. This can be useful for changing these values across
return training_ops.resource_apply_adam( different invocations of optimizer functions.
var.handle, m.handle, v.handle, @end_compatibility
math_ops.cast(beta1_power, grad.dtype.base_dtype), """
math_ops.cast(beta2_power, grad.dtype.base_dtype), super(AdamOptimizer, self).__init__(use_locking, name)
math_ops.cast(self._lr_t, grad.dtype.base_dtype), self._lr = learning_rate
math_ops.cast(self._beta1_t, grad.dtype.base_dtype), self._beta1 = beta1
math_ops.cast(self._beta2_t, grad.dtype.base_dtype), self._beta2 = beta2
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), self._epsilon = epsilon
grad, use_locking=self._use_locking)
def _apply_sparse_shared(self, grad, var, indices, scatter_add): # Tensor versions of the constructor arguments, created in _prepare().
beta1_power, beta2_power = self._get_beta_accumulators() self._lr_t = None
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) self._beta1_t = None
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) self._beta2_t = None
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) self._epsilon_t = None
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t,
use_locking=self._use_locking)
with ops.control_dependencies([m_t]):
m_t = scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
with ops.control_dependencies([v_t]):
v_t = scatter_add(v, indices, v_scaled_g_values)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(var,
lr * m_t / (v_sqrt + epsilon_t),
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
def _apply_sparse(self, grad, var): # Created in SparseApply if needed.
return self._apply_sparse_shared( self._updated_lr = None
grad.values, var, grad.indices,
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
x, i, v, use_locking=self._use_locking))
def _resource_scatter_add(self, x, i, v): def _get_beta_accumulators(self):
with ops.control_dependencies( with ops.init_scope():
[resource_variable_ops.resource_scatter_add( if context.executing_eagerly():
x.handle, i, v)]): graph = None
return x.value() else:
graph = ops.get_default_graph()
return (
self._get_non_slot_variable("beta1_power", graph=graph),
self._get_non_slot_variable("beta2_power", graph=graph),
)
def _resource_apply_sparse(self, grad, var, indices): def _create_slots(self, var_list):
return self._apply_sparse_shared( # Create the beta1 and beta2 accumulators on the same device as the first
grad, var, indices, self._resource_scatter_add) # variable. Sort the var_list to make sure this device is consistent across
# workers (these need to go on the same PS, otherwise some updates are
# silently ignored).
first_var = min(var_list, key=lambda x: x.name)
self._create_non_slot_variable(
initial_value=self._beta1, name="beta1_power", colocate_with=first_var
)
self._create_non_slot_variable(
initial_value=self._beta2, name="beta2_power", colocate_with=first_var
)
def _finish(self, update_ops, name_scope): # Create slots for the first and second moments.
# Update the power accumulators. for v in var_list:
with ops.control_dependencies(update_ops): self._zeros_slot(v, "m", self._name)
beta1_power, beta2_power = self._get_beta_accumulators() self._zeros_slot(v, "v", self._name)
with ops.colocate_with(beta1_power):
update_beta1 = beta1_power.assign( def _prepare(self):
beta1_power * self._beta1_t, use_locking=self._use_locking) lr = self._call_if_callable(self._lr)
update_beta2 = beta2_power.assign( beta1 = self._call_if_callable(self._beta1)
beta2_power * self._beta2_t, use_locking=self._use_locking) beta2 = self._call_if_callable(self._beta2)
return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], epsilon = self._call_if_callable(self._epsilon)
name=name_scope)
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
# Clip gradients by 3 std
return training_ops.apply_adam(
var,
m,
v,
math_ops.cast(beta1_power, var.dtype.base_dtype),
math_ops.cast(beta2_power, var.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
grad,
use_locking=self._use_locking,
).op
def _resource_apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.resource_apply_adam(
var.handle,
m.handle,
v.handle,
math_ops.cast(beta1_power, grad.dtype.base_dtype),
math_ops.cast(beta2_power, grad.dtype.base_dtype),
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
grad,
use_locking=self._use_locking,
)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
with ops.control_dependencies([m_t]):
m_t = scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
with ops.control_dependencies([v_t]):
v_t = scatter_add(v, indices, v_scaled_g_values)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(
var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking
)
return control_flow_ops.group(*[var_update, m_t, v_t])
def _apply_sparse(self, grad, var):
return self._apply_sparse_shared(
grad.values,
var,
grad.indices,
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
x, i, v, use_locking=self._use_locking
),
)
def _resource_scatter_add(self, x, i, v):
with ops.control_dependencies(
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]
):
return x.value()
def _resource_apply_sparse(self, grad, var, indices):
return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
def _finish(self, update_ops, name_scope):
# Update the power accumulators.
with ops.control_dependencies(update_ops):
beta1_power, beta2_power = self._get_beta_accumulators()
with ops.colocate_with(beta1_power):
update_beta1 = beta1_power.assign(
beta1_power * self._beta1_t, use_locking=self._use_locking
)
update_beta2 = beta2_power.assign(
beta2_power * self._beta2_t, use_locking=self._use_locking
)
return control_flow_ops.group(
*update_ops + [update_beta1, update_beta2], name=name_scope
)

View file

@ -1,42 +1,48 @@
from tensorflow.python.platform import flags
from tensorflow.contrib.data.python.ops import batching, threadpool
import tensorflow as tf
import json import json
from torch.utils.data import Dataset
import pickle
import os.path as osp
import os import os
import numpy as np import os.path as osp
import pickle
import time import time
from scipy.misc import imread, imresize
from skimage.color import rgb2grey import numpy as np
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder import tensorflow as tf
from torchvision import transforms
from imagenet_preprocessing import ImagenetPreprocessor
import torch import torch
import torchvision import torchvision
from imagenet_preprocessing import ImagenetPreprocessor
from scipy.misc import imread, imresize
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.platform import flags
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
ROOT_DIR = "./results" ROOT_DIR = "./results"
# Dataset Options # Dataset Options
flags.DEFINE_string('dsprites_path', flags.DEFINE_string(
'/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', "dsprites_path",
'path to dsprites characters') "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image') "path to dsprites characters",
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes') )
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes') flags.DEFINE_string(
flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects') "imagenet_datadir", "/root/imagenet_big", "whether cutoff should always in image"
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects') )
flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects') flags.DEFINE_bool("dshape_only", False, "fix all factors except for shapes")
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images') flags.DEFINE_bool("dpos_only", False, "fix all factors except for positions of shapes")
flags.DEFINE_bool("dsize_only", False, "fix all factors except for size of objects")
flags.DEFINE_bool("drot_only", False, "fix all factors except for rotation of objects")
flags.DEFINE_bool(
"dsprites_restrict", False, "fix all factors except for rotation of objects"
)
flags.DEFINE_string("imagenet_path", "/root/imagenet", "path to imagenet images")
# Data augmentation options # Data augmentation options
flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image') flags.DEFINE_bool("cutout_inside", False, "whether cutoff should always in image")
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout') flags.DEFINE_float("cutout_prob", 1.0, "probability of using cutout")
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout') flags.DEFINE_integer("cutout_mask_size", 16, "size of cutout")
flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data') flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")
def cutout(mask_color=(0, 0, 0)): def cutout(mask_color=(0, 0, 0)):
@ -91,13 +97,15 @@ class TFImagenetLoader(Dataset):
self.curr_sample = 0 self.curr_sample = 0
index_path = osp.join(FLAGS.imagenet_datadir, 'index.json') index_path = osp.join(FLAGS.imagenet_datadir, "index.json")
with open(index_path) as f: with open(index_path) as f:
metadata = json.load(f) metadata = json.load(f)
counts = metadata['record_counts'] counts = metadata["record_counts"]
if split == 'train': if split == "train":
file_names = list(sorted([x for x in counts.keys() if x.startswith('train')])) file_names = list(
sorted([x for x in counts.keys() if x.startswith("train")])
)
result_records_to_skip = None result_records_to_skip = None
files = [] files = []
@ -111,30 +119,44 @@ class TFImagenetLoader(Dataset):
# Record the number to skip in the first file # Record the number to skip in the first file
result_records_to_skip = records_to_skip result_records_to_skip = records_to_skip
files.append(filename) files.append(filename)
records_to_read -= (records_in_file - records_to_skip) records_to_read -= records_in_file - records_to_skip
records_to_skip = 0 records_to_skip = 0
else: else:
break break
else: else:
files = list(sorted([x for x in counts.keys() if x.startswith('validation')])) files = list(
sorted([x for x in counts.keys() if x.startswith("validation")])
)
files = [osp.join(FLAGS.imagenet_datadir, x) for x in files] files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess preprocess_function = ImagenetPreprocessor(
128, dtype=tf.float32, train=False
).parse_and_preprocess
ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string) ds = tf.data.TFRecordDataset.from_generator(
lambda: files, output_types=tf.string
)
ds = ds.apply(tf.data.TFRecordDataset) ds = ds.apply(tf.data.TFRecordDataset)
ds = ds.take(im_length) ds = ds.take(im_length)
ds = ds.prefetch(buffer_size=FLAGS.batch_size) ds = ds.prefetch(buffer_size=FLAGS.batch_size)
ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000)) ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4)) ds = ds.apply(
batching.map_and_batch(
map_func=preprocess_function,
batch_size=FLAGS.batch_size,
num_parallel_batches=4,
)
)
ds = ds.prefetch(buffer_size=2) ds = ds.prefetch(buffer_size=2)
ds_iterator = ds.make_initializable_iterator() ds_iterator = ds.make_initializable_iterator()
labels, images = ds_iterator.get_next() labels, images = ds_iterator.get_next()
self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0) self.images = tf.clip_by_value(
images / 256 + tf.random_uniform(tf.shape(images), 0, 1.0 / 256), 0.0, 1.0
)
self.labels = labels self.labels = labels
config = tf.ConfigProto(device_count = {'GPU': 0}) config = tf.ConfigProto(device_count={"GPU": 0})
sess = tf.Session(config=config) sess = tf.Session(config=config)
sess.run(ds_iterator.initializer) sess.run(ds_iterator.initializer)
@ -147,11 +169,17 @@ class TFImagenetLoader(Dataset):
sess = self.sess sess = self.sess
im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)) im_corrupt = np.random.uniform(
0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)
)
label, im = sess.run([self.labels, self.images]) label, im = sess.run([self.labels, self.images])
im = im * self.rescale im = im * self.rescale
label = np.eye(1000)[label.squeeze() - 1] label = np.eye(1000)[label.squeeze() - 1]
im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label) im, im_corrupt, label = (
torch.from_numpy(im),
torch.from_numpy(im_corrupt),
torch.from_numpy(label),
)
return im_corrupt, im, label return im_corrupt, im, label
def __iter__(self): def __iter__(self):
@ -160,6 +188,7 @@ class TFImagenetLoader(Dataset):
def __len__(self): def __len__(self):
return self.im_length return self.im_length
class CelebA(Dataset): class CelebA(Dataset):
def __init__(self): def __init__(self):
@ -180,25 +209,18 @@ class CelebA(Dataset):
im = imread(path) im = imread(path)
im = imresize(im, (32, 32)) im = imresize(im, (32, 32))
image_size = 32 image_size = 32
im = im / 255. im = im / 255.0
if FLAGS.datasource == 'default': if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random': elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform( im_corrupt = np.random.uniform(0, 1, size=(image_size, image_size, 3))
0, 1, size=(image_size, image_size, 3))
return im_corrupt, im, label return im_corrupt, im, label
class Cifar10(Dataset): class Cifar10(Dataset):
def __init__( def __init__(self, train=True, full=False, augment=False, noise=True, rescale=1.0):
self,
train=True,
full=False,
augment=False,
noise=True,
rescale=1.0):
if augment: if augment:
transform_list = [ transform_list = [
@ -215,16 +237,10 @@ class Cifar10(Dataset):
transform = transforms.ToTensor() transform = transforms.ToTensor()
self.full = full self.full = full
self.data = CIFAR10( self.data = CIFAR10(ROOT_DIR, transform=transform, train=train, download=True)
ROOT_DIR,
transform=transform,
train=train,
download=True)
self.test_data = CIFAR10( self.test_data = CIFAR10(
ROOT_DIR, ROOT_DIR, transform=transform, train=False, download=True
transform=transform, )
train=False,
download=True)
self.one_hot_map = np.eye(10) self.one_hot_map = np.eye(10)
self.noise = noise self.noise = noise
self.rescale = rescale self.rescale = rescale
@ -255,16 +271,18 @@ class Cifar10(Dataset):
im = im * 255 / 256 im = im * 255 / 256
if self.noise: if self.noise:
im = im * self.rescale + \ im = im * self.rescale + np.random.uniform(
np.random.uniform(0, self.rescale * 1 / 256., im.shape) 0, self.rescale * 1 / 256.0, im.shape
)
np.random.seed((index + int(time.time() * 1e7)) % 2**32) np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default': if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random': elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform( im_corrupt = np.random.uniform(
0.0, self.rescale, (image_size, image_size, 3)) 0.0, self.rescale, (image_size, image_size, 3)
)
return im_corrupt, im, label return im_corrupt, im, label
@ -287,10 +305,8 @@ class Cifar100(Dataset):
transform = transforms.ToTensor() transform = transforms.ToTensor()
self.data = CIFAR100( self.data = CIFAR100(
"/root/cifar100", "/root/cifar100", transform=transform, train=train, download=True
transform=transform, )
train=train,
download=True)
self.one_hot_map = np.eye(100) self.one_hot_map = np.eye(100)
def __len__(self): def __len__(self):
@ -308,11 +324,10 @@ class Cifar100(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32) np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default': if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random': elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform( im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label return im_corrupt, im, label
@ -340,11 +355,10 @@ class Svhn(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32) np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default': if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random': elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform( im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label return im_corrupt, im, label
@ -352,9 +366,8 @@ class Svhn(Dataset):
class Mnist(Dataset): class Mnist(Dataset):
def __init__(self, train=True, rescale=1.0): def __init__(self, train=True, rescale=1.0):
self.data = MNIST( self.data = MNIST(
"/root/mnist", "/root/mnist", transform=transforms.ToTensor(), download=True, train=train
transform=transforms.ToTensor(), )
download=True, train=train)
self.labels = np.eye(10) self.labels = np.eye(10)
self.rescale = rescale self.rescale = rescale
@ -367,13 +380,13 @@ class Mnist(Dataset):
im = im.squeeze() im = im.squeeze()
# im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28)) # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
# im = im.numpy() / 2 + 0.2 # im = im.numpy() / 2 + 0.2
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28)) im = im.numpy() / 256 * 255 + np.random.uniform(0, 1.0 / 256, (28, 28))
im = im * self.rescale im = im * self.rescale
image_size = 28 image_size = 28
if FLAGS.datasource == 'default': if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size) im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
elif FLAGS.datasource == 'random': elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform(0, self.rescale, (28, 28)) im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
return im_corrupt, im, label return im_corrupt, im, label
@ -381,54 +394,63 @@ class Mnist(Dataset):
class DSprites(Dataset): class DSprites(Dataset):
def __init__( def __init__(
self, self, cond_size=False, cond_shape=False, cond_pos=False, cond_rot=False
cond_size=False, ):
cond_shape=False,
cond_pos=False,
cond_rot=False):
dat = np.load(FLAGS.dsprites_path) dat = np.load(FLAGS.dsprites_path)
if FLAGS.dshape_only: if FLAGS.dshape_only:
l = dat['latents_values'] l = dat["latents_values"]
mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / mask = (
31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) (l[:, 4] == 16 / 31)
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1)) & (l[:, 5] == 16 / 31)
self.label = np.tile(dat['latents_values'][mask], (10000, 1)) & (l[:, 2] == 0.5)
& (l[:, 3] == 30 * np.pi / 39)
)
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
self.label = self.label[:, 1:2] self.label = self.label[:, 1:2]
elif FLAGS.dpos_only: elif FLAGS.dpos_only:
l = dat['latents_values'] l = dat["latents_values"]
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
mask = (l[:, 1] == 1) & ( mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5) self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
self.data = np.tile(dat['imgs'][mask], (100, 1, 1)) self.label = np.tile(dat["latents_values"][mask], (100, 1))
self.label = np.tile(dat['latents_values'][mask], (100, 1))
self.label = self.label[:, 4:] + 0.5 self.label = self.label[:, 4:] + 0.5
elif FLAGS.dsize_only: elif FLAGS.dsize_only:
l = dat['latents_values'] l = dat["latents_values"]
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 / mask = (
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1) (l[:, 3] == 30 * np.pi / 39)
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1)) & (l[:, 4] == 16 / 31)
self.label = np.tile(dat['latents_values'][mask], (10000, 1)) & (l[:, 5] == 16 / 31)
self.label = (self.label[:, 2:3]) & (l[:, 1] == 1)
)
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
self.label = self.label[:, 2:3]
elif FLAGS.drot_only: elif FLAGS.drot_only:
l = dat['latents_values'] l = dat["latents_values"]
mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 / mask = (
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1) (l[:, 2] == 0.5)
self.data = np.tile(dat['imgs'][mask], (100, 1, 1)) & (l[:, 4] == 16 / 31)
self.label = np.tile(dat['latents_values'][mask], (100, 1)) & (l[:, 5] == 16 / 31)
self.label = (self.label[:, 3:4]) & (l[:, 1] == 1)
)
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
self.label = np.tile(dat["latents_values"][mask], (100, 1))
self.label = self.label[:, 3:4]
self.label = np.concatenate( self.label = np.concatenate(
[np.cos(self.label), np.sin(self.label)], axis=1) [np.cos(self.label), np.sin(self.label)], axis=1
)
elif FLAGS.dsprites_restrict: elif FLAGS.dsprites_restrict:
l = dat['latents_values'] l = dat["latents_values"]
mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39) mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
self.data = dat['imgs'][mask] self.data = dat["imgs"][mask]
self.label = dat['latents_values'][mask] self.label = dat["latents_values"][mask]
else: else:
self.data = dat['imgs'] self.data = dat["imgs"]
self.label = dat['latents_values'] self.label = dat["latents_values"]
if cond_size: if cond_size:
self.label = self.label[:, 2:3] self.label = self.label[:, 2:3]
@ -439,7 +461,8 @@ class DSprites(Dataset):
elif cond_rot: elif cond_rot:
self.label = self.label[:, 3:4] self.label = self.label[:, 3:4]
self.label = np.concatenate( self.label = np.concatenate(
[np.cos(self.label), np.sin(self.label)], axis=1) [np.cos(self.label), np.sin(self.label)], axis=1
)
else: else:
self.label = self.label[:, 1:2] self.label = self.label[:, 1:2]
@ -452,20 +475,20 @@ class DSprites(Dataset):
im = self.data[index] im = self.data[index]
image_size = 64 image_size = 64
if not ( if (
FLAGS.dpos_only or FLAGS.dsize_only) and ( not (FLAGS.dpos_only or FLAGS.dsize_only)
not FLAGS.cond_size) and ( and (not FLAGS.cond_size)
not FLAGS.cond_pos) and ( and (not FLAGS.cond_pos)
not FLAGS.cond_rot) and ( and (not FLAGS.cond_rot)
not FLAGS.drot_only): and (not FLAGS.drot_only)
label = self.identity[self.label[index].astype( ):
np.int32) - 1].squeeze() label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
else: else:
label = self.label[index] label = self.label[index]
if FLAGS.datasource == 'default': if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size) im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
elif FLAGS.datasource == 'random': elif FLAGS.datasource == "random":
im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size) im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
return im_corrupt, im, label return im_corrupt, im, label
@ -478,25 +501,20 @@ class Imagenet(Dataset):
for i in range(1, 11): for i in range(1, 11):
f = pickle.load( f = pickle.load(
open( open(
osp.join( osp.join(FLAGS.imagenet_path, "train_data_batch_{}".format(i)),
FLAGS.imagenet_path, "rb",
'train_data_batch_{}'.format(i)), )
'rb')) )
if i == 1: if i == 1:
labels = f['labels'] labels = f["labels"]
data = f['data'] data = f["data"]
else: else:
labels.extend(f['labels']) labels.extend(f["labels"])
data = np.vstack((data, f['data'])) data = np.vstack((data, f["data"]))
else: else:
f = pickle.load( f = pickle.load(open(osp.join(FLAGS.imagenet_path, "val_data"), "rb"))
open( labels = f["labels"]
osp.join( data = f["data"]
FLAGS.imagenet_path,
'val_data'),
'rb'))
labels = f['labels']
data = f['data']
self.labels = labels self.labels = labels
self.data = data self.data = data
@ -520,11 +538,10 @@ class Imagenet(Dataset):
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
np.random.seed((index + int(time.time() * 1e7)) % 2**32) np.random.seed((index + int(time.time() * 1e7)) % 2**32)
if FLAGS.datasource == 'default': if FLAGS.datasource == "default":
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
elif FLAGS.datasource == 'random': elif FLAGS.datasource == "random":
im_corrupt = np.random.uniform( im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
0.0, 1.0, (image_size, image_size, 3))
return im_corrupt, im, label return im_corrupt, im, label

View file

@ -1,68 +1,100 @@
import os
import os.path as osp
import numpy as np
import tensorflow as tf import tensorflow as tf
import math from custom_adam import AdamOptimizer
from tqdm import tqdm from models import DspritesNet
from hmc import hmc from scipy.misc import imsave
from tensorflow.python.platform import flags from tensorflow.python.platform import flags
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from models import DspritesNet from tqdm import tqdm
from utils import optimistic_restore, ReplayBuffer from utils import ReplayBuffer
import os.path as osp
import numpy as np
from rl_algs.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
from custom_adam import AdamOptimizer
flags.DEFINE_integer('batch_size', 256, 'Size of inputs') flags.DEFINE_integer("batch_size", 256, "Size of inputs")
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things') flags.DEFINE_integer("data_workers", 4, "Number of workers to do things")
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging') flags.DEFINE_string("logdir", "cachedir", "directory for logging")
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored') flags.DEFINE_string(
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.') "savedir", "cachedir", "location where log of experiments will be stored"
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size') )
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters') flags.DEFINE_integer(
flags.DEFINE_bool('cclass', True, 'not cclass') "num_filters",
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons') 64,
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') "number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') )
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') flags.DEFINE_float("step_lr", 500, "size of gradient descent size")
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results') flags.DEFINE_string(
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label') "dsprites_path",
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.') "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape') "path to dsprites characters",
flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape') )
flags.DEFINE_bool("cclass", True, "not cclass")
flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_bool("plot_curve", False, "Generate a curve of results")
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
flags.DEFINE_string(
"task",
"conceptcombine",
"conceptcombine, labeldiscover, gentest, genbaseline, etc.",
)
flags.DEFINE_bool("joint_shape", False, "whether to use pos_size or pos_shape")
flags.DEFINE_bool("joint_rot", False, "whether to use pos_size or pos_shape")
# Conditions on which models to use # Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position') flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation') flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape') flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale') flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")
flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments') flags.DEFINE_string("exp_size", "dsprites_2018_cond_size", "name of experiments")
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments') flags.DEFINE_string("exp_shape", "dsprites_2018_cond_shape", "name of experiments")
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments') flags.DEFINE_string("exp_pos", "dsprites_2018_cond_pos_cert", "name of experiments")
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments') flags.DEFINE_string("exp_rot", "dsprites_cond_rot_119_00", "name of experiments")
flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume') flags.DEFINE_integer("resume_size", 169000, "First iteration to resume")
flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume') flags.DEFINE_integer("resume_shape", 477000, "Second iteration to resume")
flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume') flags.DEFINE_integer("resume_pos", 8000, "Second iteration to resume")
flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume') flags.DEFINE_integer("resume_rot", 690000, "Second iteration to resume")
flags.DEFINE_integer('break_steps', 300, 'steps to break') flags.DEFINE_integer("break_steps", 300, "steps to break")
# Whether to train for gentest # Whether to train for gentest
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions') flags.DEFINE_bool(
"train",
False,
"whether to train on generalization into multiple different predictions",
)
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class DSpritesGen(Dataset): class DSpritesGen(Dataset):
def __init__(self, data, latents, frac=0.0): def __init__(self, data, latents, frac=0.0):
l = latents l = latents
if FLAGS.joint_shape: if FLAGS.joint_shape:
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5) mask_size = (
(l[:, 3] == 30 * np.pi / 39)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 2] == 0.5)
)
elif FLAGS.joint_rot: elif FLAGS.joint_rot:
mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5) mask_size = (
(l[:, 1] == 1)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 2] == 0.5)
)
else: else:
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1) mask_size = (
(l[:, 3] == 30 * np.pi / 39)
& (l[:, 4] == 16 / 31)
& (l[:, 5] == 16 / 31)
& (l[:, 1] == 1)
)
mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5) mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
@ -80,12 +112,14 @@ class DSpritesGen(Dataset):
self.data = np.concatenate((data_pos, data_size), axis=0) self.data = np.concatenate((data_pos, data_size), axis=0)
self.label = np.concatenate((l_pos, l_size), axis=0) self.label = np.concatenate((l_pos, l_size), axis=0)
mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)) mask_neg = (~(mask_size & mask_pos)) & (
(l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)
)
data_add = data[mask_neg] data_add = data[mask_neg]
l_add = l[mask_neg] l_add = l[mask_neg]
perm_idx = np.random.permutation(data_add.shape[0]) perm_idx = np.random.permutation(data_add.shape[0])
select_idx = perm_idx[:int(frac*perm_idx.shape[0])] select_idx = perm_idx[: int(frac * perm_idx.shape[0])]
data_add = data_add[select_idx] data_add = data_add[select_idx]
l_add = l_add[select_idx] l_add = l_add[select_idx]
@ -104,7 +138,9 @@ class DSpritesGen(Dataset):
if FLAGS.joint_shape: if FLAGS.joint_shape:
label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1] label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
elif FLAGS.joint_rot: elif FLAGS.joint_rot:
label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]) label_size = np.array(
[np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]
)
else: else:
label_size = self.label[index, 2:3] label_size = self.label[index, 2:3]
@ -114,14 +150,16 @@ class DSpritesGen(Dataset):
def labeldiscover(sess, kvs, data, latents, save_exp_dir): def labeldiscover(sess, kvs, data, latents, save_exp_dir):
LABEL_SIZE = kvs['LABEL_SIZE'] LABEL_SIZE = kvs["LABEL_SIZE"]
model_size = kvs['model_size'] model_size = kvs["model_size"]
weight_size = kvs['weight_size'] weight_size = kvs["weight_size"]
x_mod = kvs['X_NOISE'] x_mod = kvs["X_NOISE"]
label_output = LABEL_SIZE label_output = LABEL_SIZE
for i in range(FLAGS.num_steps): for i in range(FLAGS.num_steps):
label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03) label_output = label_output + tf.random_normal(
tf.shape(label_output), mean=0.0, stddev=0.03
)
e_noise = model_size.forward(x_mod, weight_size, label=label_output) e_noise = model_size.forward(x_mod, weight_size, label=label_output)
label_grad = tf.gradients(e_noise, [label_output])[0] label_grad = tf.gradients(e_noise, [label_output])[0]
# label_grad = tf.Print(label_grad, [label_grad]) # label_grad = tf.Print(label_grad, [label_grad])
@ -130,13 +168,13 @@ def labeldiscover(sess, kvs, data, latents, save_exp_dir):
diffs = [] diffs = []
for i in range(30): for i in range(30):
s = i*FLAGS.batch_size s = i * FLAGS.batch_size
d = (i+1)*FLAGS.batch_size d = (i + 1) * FLAGS.batch_size
data_i = data[s:d] data_i = data[s:d]
latent_i = latents[s:d] latent_i = latents[s:d]
latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1)) latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init} feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
size_pred = sess.run([label_output], feed_dict)[0] size_pred = sess.run([label_output], feed_dict)[0]
size_gt = latent_i[:, 2:3] size_gt = latent_i[:, 2:3]
@ -155,9 +193,11 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3) model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac)) weights_baseline = model_baseline.construct_weights(
"context_baseline_{}".format(frac)
)
X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32) X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL) X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
@ -168,14 +208,20 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
gvs = [(k, v) for (k, v) in gvs if k is not None] gvs = [(k, v) for (k, v) in gvs if k is not None]
train_op = optimizer.apply_gradients(gvs) train_op = optimizer.apply_gradients(gvs)
dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True) dataloader = DataLoader(
DSpritesGen(data, latents, frac=frac),
batch_size=FLAGS.batch_size,
num_workers=6,
drop_last=True,
shuffle=True,
)
datafull = data datafull = data
itr = 0 itr = 0
saver = tf.train.Saver() saver = tf.train.Saver()
vs = optimizer.variables() optimizer.variables()
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
if FLAGS.train: if FLAGS.train:
@ -185,7 +231,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
data_corrupt = data_corrupt.numpy() data_corrupt = data_corrupt.numpy()
label_size, label_pos = label_size.numpy(), label_pos.numpy() label_size, label_pos = label_size.numpy(), label_pos.numpy()
data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters) data_corrupt = np.random.randn(
data_corrupt.shape[0], 2 * FLAGS.num_filters
)
label_comb = np.concatenate([label_size, label_pos], axis=1) label_comb = np.concatenate([label_size, label_pos], axis=1)
feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb} feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
@ -196,23 +244,27 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
itr += 1 itr += 1
saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline')) saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))
saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline')) saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))
l = latents l = latents
if FLAGS.joint_shape: if FLAGS.joint_shape:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5) mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
else: else:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31)))) mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
)
data_gen = datafull[mask_gen] data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen] latents_gen = latents[mask_gen]
losses = [] losses = []
for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)): for dat, latent in zip(
data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters) np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
):
data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)
if FLAGS.joint_shape: if FLAGS.joint_shape:
latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1] latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
@ -220,16 +272,19 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
latent = np.concatenate([latent_size, latent_pos], axis=1) latent = np.concatenate([latent_size, latent_pos], axis=1)
feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat} feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
else: else:
feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat} feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}
loss = sess.run([loss_sq], feed_dict=feed_dict)[0] loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
# print(loss) # print(loss)
losses.append(loss) losses.append(loss)
print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac)) print(
"Overall MSE for generalization of {} for fraction of {}".format(
np.mean(losses), frac
)
)
data_try = data_gen[:10] data_try = data_gen[:10]
data_init = np.random.randn(10, 2*FLAGS.num_filters) data_init = np.random.randn(10, 2 * FLAGS.num_filters)
if FLAGS.joint_shape: if FLAGS.joint_shape:
latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1] latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
@ -252,7 +307,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
x_output_wrap[:, 1:-1, 1:-1] = x_output x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try data_try_wrap[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2) im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
-1, 66 * 2
)
impath = osp.join(save_exp_dir, im_name) impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output) imsave(impath, im_output)
print("Successfully saved images at {}".format(impath)) print("Successfully saved images at {}".format(impath))
@ -261,38 +318,40 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
def gentest(sess, kvs, data, latents, save_exp_dir): def gentest(sess, kvs, data, latents, save_exp_dir):
X_NOISE = kvs['X_NOISE'] X_NOISE = kvs["X_NOISE"]
LABEL_SIZE = kvs['LABEL_SIZE'] LABEL_SIZE = kvs["LABEL_SIZE"]
LABEL_SHAPE = kvs['LABEL_SHAPE'] LABEL_SHAPE = kvs["LABEL_SHAPE"]
LABEL_POS = kvs['LABEL_POS'] LABEL_POS = kvs["LABEL_POS"]
LABEL_ROT = kvs['LABEL_ROT'] LABEL_ROT = kvs["LABEL_ROT"]
model_size = kvs['model_size'] model_size = kvs["model_size"]
model_shape = kvs['model_shape'] model_shape = kvs["model_shape"]
model_pos = kvs['model_pos'] model_pos = kvs["model_pos"]
model_rot = kvs['model_rot'] model_rot = kvs["model_rot"]
weight_size = kvs['weight_size'] weight_size = kvs["weight_size"]
weight_shape = kvs['weight_shape'] weight_shape = kvs["weight_shape"]
weight_pos = kvs['weight_pos'] weight_pos = kvs["weight_pos"]
weight_rot = kvs['weight_rot'] weight_rot = kvs["weight_rot"]
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
datafull = data datafull = data
# Test combination of generalization where we use slices of both training # Test combination of generalization where we use slices of both training
x_final = X_NOISE x_final = X_NOISE
x_mod_size = X_NOISE
x_mod_pos = X_NOISE x_mod_pos = X_NOISE
for i in range(FLAGS.num_steps): for i in range(FLAGS.num_steps):
# use cond_pos # use cond_pos
energies = [] x_mod_pos = x_mod_pos + tf.random_normal(
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005) tf.shape(x_mod_pos), mean=0.0, stddev=0.005
)
e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS) e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
# energies.append(e_noise) # energies.append(e_noise)
x_grad = tf.gradients(e_noise, [x_final])[0] x_grad = tf.gradients(e_noise, [x_final])[0]
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005) x_mod_pos = x_mod_pos + tf.random_normal(
tf.shape(x_mod_pos), mean=0.0, stddev=0.005
)
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1) x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
@ -332,42 +391,70 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
x_mod = x_mod_pos x_mod = x_mod_pos
x_final = x_mod x_final = x_mod
if FLAGS.joint_shape: if FLAGS.joint_shape:
loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \ loss_kl = model_shape.forward(
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True
) + model_pos.forward(
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
)
energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \ energy_pos = model_shape.forward(
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) X, weight_shape, reuse=True, label=LABEL_SHAPE
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \ energy_neg = model_shape.forward(
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE
) + model_pos.forward(
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
)
elif FLAGS.joint_rot: elif FLAGS.joint_rot:
loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \ loss_kl = model_rot.forward(
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True
) + model_pos.forward(
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
)
energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \ energy_pos = model_rot.forward(
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) X, weight_rot, reuse=True, label=LABEL_ROT
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \ energy_neg = model_rot.forward(
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT
) + model_pos.forward(
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
)
else: else:
loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \ loss_kl = model_size.forward(
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True
) + model_pos.forward(
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
)
energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \ energy_pos = model_size.forward(
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) X, weight_size, reuse=True, label=LABEL_SIZE
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \ energy_neg = model_size.forward(
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE
) + model_pos.forward(
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
)
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg)) energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced)) coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4 norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
neg_loss = coeff * (-1*energy_neg) / norm_constant coeff * (-1 * energy_neg) / norm_constant
loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg) loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg))) loss_total = (
loss_ml
+ tf.reduce_mean(loss_kl)
+ 1
* (
tf.reduce_mean(tf.square(energy_pos))
+ tf.reduce_mean(tf.square(energy_neg))
)
)
optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999) optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
gvs = optimizer.compute_gradients(loss_total) gvs = optimizer.compute_gradients(loss_total)
@ -377,7 +464,13 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
vs = optimizer.variables() vs = optimizer.variables()
sess.run(tf.variables_initializer(vs)) sess.run(tf.variables_initializer(vs))
dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True) dataloader = DataLoader(
DSpritesGen(data, latents),
batch_size=FLAGS.batch_size,
num_workers=6,
drop_last=True,
shuffle=True,
)
x_off = tf.reduce_mean(tf.square(x_mod - X)) x_off = tf.reduce_mean(tf.square(x_mod - X))
@ -385,12 +478,10 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
saver = tf.train.Saver() saver = tf.train.Saver()
x_mod = None x_mod = None
if FLAGS.train: if FLAGS.train:
replay_buffer = ReplayBuffer(10000) replay_buffer = ReplayBuffer(10000)
for _ in range(1): for _ in range(1):
for data_corrupt, data, label_size, label_pos in tqdm(dataloader): for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
data_corrupt = data_corrupt.numpy()[:, :, :] data_corrupt = data_corrupt.numpy()[:, :, :]
data = data.numpy()[:, :, :] data = data.numpy()[:, :, :]
@ -398,29 +489,50 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
if x_mod is not None: if x_mod is not None:
replay_buffer.add(x_mod) replay_buffer.add(x_mod)
replay_batch = replay_buffer.sample(FLAGS.batch_size) replay_batch = replay_buffer.sample(FLAGS.batch_size)
replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95) replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95
data_corrupt[replay_mask] = replay_batch[replay_mask] data_corrupt[replay_mask] = replay_batch[replay_mask]
if FLAGS.joint_shape: if FLAGS.joint_shape:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos} feed_dict = {
X_NOISE: data_corrupt,
X: data,
LABEL_SHAPE: label_size,
LABEL_POS: label_pos,
}
elif FLAGS.joint_rot: elif FLAGS.joint_rot:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos} feed_dict = {
X_NOISE: data_corrupt,
X: data,
LABEL_ROT: label_size,
LABEL_POS: label_pos,
}
else: else:
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos} feed_dict = {
X_NOISE: data_corrupt,
X: data,
LABEL_SIZE: label_size,
LABEL_POS: label_pos,
}
_, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict) _, off_value, e_pos, e_neg, x_mod = sess.run(
[train_op, x_off, energy_pos, energy_neg, x_final],
feed_dict=feed_dict,
)
itr += 1 itr += 1
if itr % 10 == 0: if itr % 10 == 0:
print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr)) print(
"x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
off_value, e_pos.mean(), e_neg.mean(), itr
)
)
if itr == FLAGS.break_steps: if itr == FLAGS.break_steps:
break break
saver.save(sess, osp.join(save_exp_dir, "model_gentest"))
saver.save(sess, osp.join(save_exp_dir, 'model_gentest')) saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))
saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
l = latents l = latents
@ -429,22 +541,43 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
elif FLAGS.joint_rot: elif FLAGS.joint_rot:
mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5) mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
else: else:
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31)))) mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
)
data_gen = datafull[mask_gen] data_gen = datafull[mask_gen]
latents_gen = latents[mask_gen] latents_gen = latents[mask_gen]
losses = [] losses = []
for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)): for dat, latent in zip(
np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
):
x = 0.5 + np.random.randn(*dat.shape) x = 0.5 + np.random.randn(*dat.shape)
if FLAGS.joint_shape: if FLAGS.joint_shape:
feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} feed_dict = {
LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
LABEL_POS: latent[:, 4:],
X_NOISE: x,
X: dat,
}
elif FLAGS.joint_rot: elif FLAGS.joint_rot:
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} feed_dict = {
LABEL_ROT: np.concatenate(
[np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
),
LABEL_POS: latent[:, 4:],
X_NOISE: x,
X: dat,
}
else: else:
feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} feed_dict = {
LABEL_SIZE: latent[:, 2:3],
LABEL_POS: latent[:, 4:],
X_NOISE: x,
X: dat,
}
for i in range(2): for i in range(2):
x = sess.run([x_final], feed_dict=feed_dict)[0] x = sess.run([x_final], feed_dict=feed_dict)[0]
@ -461,11 +594,25 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
latent_pos = latents_gen[:10, 4:] latent_pos = latents_gen[:10, 4:]
if FLAGS.joint_shape: if FLAGS.joint_shape:
feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos} feed_dict = {
X_NOISE: data_init,
LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32) - 1],
LABEL_POS: latent_pos,
}
elif FLAGS.joint_rot: elif FLAGS.joint_rot:
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init} feed_dict = {
LABEL_ROT: np.concatenate(
[np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1
),
LABEL_POS: latent[:10, 4:],
X_NOISE: data_init,
}
else: else:
feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos} feed_dict = {
X_NOISE: data_init,
LABEL_SIZE: latent_scale,
LABEL_POS: latent_pos,
}
x_output = sess.run([x_final], feed_dict=feed_dict)[0] x_output = sess.run([x_final], feed_dict=feed_dict)[0]
@ -480,27 +627,28 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
x_output_wrap[:, 1:-1, 1:-1] = x_output x_output_wrap[:, 1:-1, 1:-1] = x_output
data_try_wrap[:, 1:-1, 1:-1] = data_try data_try_wrap[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2) im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
-1, 66 * 2
)
impath = osp.join(save_exp_dir, im_name) impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output) imsave(impath, im_output)
print("Successfully saved images at {}".format(impath)) print("Successfully saved images at {}".format(impath))
def conceptcombine(sess, kvs, data, latents, save_exp_dir): def conceptcombine(sess, kvs, data, latents, save_exp_dir):
X_NOISE = kvs['X_NOISE'] X_NOISE = kvs["X_NOISE"]
LABEL_SIZE = kvs['LABEL_SIZE'] LABEL_SIZE = kvs["LABEL_SIZE"]
LABEL_SHAPE = kvs['LABEL_SHAPE'] LABEL_SHAPE = kvs["LABEL_SHAPE"]
LABEL_POS = kvs['LABEL_POS'] LABEL_POS = kvs["LABEL_POS"]
LABEL_ROT = kvs['LABEL_ROT'] LABEL_ROT = kvs["LABEL_ROT"]
model_size = kvs['model_size'] model_size = kvs["model_size"]
model_shape = kvs['model_shape'] model_shape = kvs["model_shape"]
model_pos = kvs['model_pos'] model_pos = kvs["model_pos"]
model_rot = kvs['model_rot'] model_rot = kvs["model_rot"]
weight_size = kvs['weight_size'] weight_size = kvs["weight_size"]
weight_shape = kvs['weight_shape'] weight_shape = kvs["weight_shape"]
weight_pos = kvs['weight_pos'] weight_pos = kvs["weight_pos"]
weight_rot = kvs['weight_rot'] weight_rot = kvs["weight_rot"]
x_mod = X_NOISE x_mod = X_NOISE
for i in range(FLAGS.num_steps): for i in range(FLAGS.num_steps):
@ -540,13 +688,18 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
data_try = data[:10] data_try = data[:10]
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64) data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
label_scale = latents[:10, 2:3] label_scale = latents[:10, 2:3]
label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)] label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
label_rot = latents[:10, 3:4] label_rot = latents[:10, 3:4]
label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1) label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
label_pos = latents[:10, 4:] label_pos = latents[:10, 4:]
feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos, feed_dict = {
LABEL_ROT: label_rot} X_NOISE: data_init,
LABEL_SIZE: label_scale,
LABEL_SHAPE: label_shape,
LABEL_POS: label_pos,
LABEL_ROT: label_rot,
}
x_out = sess.run([x_final], feed_dict)[0] x_out = sess.run([x_final], feed_dict)[0]
im_name = "im" im_name = "im"
@ -569,14 +722,15 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
x_out_pad[:, 1:-1, 1:-1] = x_out x_out_pad[:, 1:-1, 1:-1] = x_out
data_try_pad[:, 1:-1, 1:-1] = data_try data_try_pad[:, 1:-1, 1:-1] = data_try
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2) im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
impath = osp.join(save_exp_dir, im_name) impath = osp.join(save_exp_dir, im_name)
imsave(impath, im_output) imsave(impath, im_output)
print("Successfully saved images at {}".format(impath)) print("Successfully saved images at {}".format(impath))
def main(): def main():
data = np.load(FLAGS.dsprites_path)['imgs'] data = np.load(FLAGS.dsprites_path)["imgs"]
l = latents = np.load(FLAGS.dsprites_path)['latents_values'] l = latents = np.load(FLAGS.dsprites_path)["latents_values"]
np.random.seed(1) np.random.seed(1)
idx = np.random.permutation(data.shape[0]) idx = np.random.permutation(data.shape[0])
@ -589,52 +743,74 @@ def main():
# Model 1 will be conditioned on size # Model 1 will be conditioned on size
model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True) model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
weight_size = model_size.construct_weights('context_0') weight_size = model_size.construct_weights("context_0")
# Model 2 will be conditioned on shape # Model 2 will be conditioned on shape
model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True) model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
weight_shape = model_shape.construct_weights('context_1') weight_shape = model_shape.construct_weights("context_1")
# Model 3 will be conditioned on position # Model 3 will be conditioned on position
model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True) model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
weight_pos = model_pos.construct_weights('context_2') weight_pos = model_pos.construct_weights("context_2")
# Model 4 will be conditioned on rotation # Model 4 will be conditioned on rotation
model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True) model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
weight_rot = model_rot.construct_weights('context_3') weight_rot = model_rot.construct_weights("context_3")
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size)) save_path_size = osp.join(
FLAGS.logdir, FLAGS.exp_size, "model_{}".format(FLAGS.resume_size)
)
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0)) v_list = tf.get_collection(
v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list} tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(0)
)
v_map = {
(v.name.replace("context_{}".format(0), "context_0")[:-2]): v for v in v_list
}
if FLAGS.cond_scale: if FLAGS.cond_scale:
saver = tf.train.Saver(v_map) saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_size) saver.restore(sess, save_path_size)
save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape)) save_path_shape = osp.join(
FLAGS.logdir, FLAGS.exp_shape, "model_{}".format(FLAGS.resume_shape)
)
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1)) v_list = tf.get_collection(
v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list} tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(1)
)
v_map = {
(v.name.replace("context_{}".format(1), "context_0")[:-2]): v for v in v_list
}
if FLAGS.cond_shape: if FLAGS.cond_shape:
saver = tf.train.Saver(v_map) saver = tf.train.Saver(v_map)
saver.restore(sess, save_path_shape) saver.restore(sess, save_path_shape)
save_path_pos = osp.join(
save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos)) FLAGS.logdir, FLAGS.exp_pos, "model_{}".format(FLAGS.resume_pos)
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2)) )
v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list} v_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(2)
)
v_map = {
(v.name.replace("context_{}".format(2), "context_0")[:-2]): v for v in v_list
}
saver = tf.train.Saver(v_map) saver = tf.train.Saver(v_map)
if FLAGS.cond_pos: if FLAGS.cond_pos:
saver.restore(sess, save_path_pos) saver.restore(sess, save_path_pos)
save_path_rot = osp.join(
save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot)) FLAGS.logdir, FLAGS.exp_rot, "model_{}".format(FLAGS.resume_rot)
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3)) )
v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list} v_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(3)
)
v_map = {
(v.name.replace("context_{}".format(3), "context_0")[:-2]): v for v in v_list
}
saver = tf.train.Saver(v_map) saver = tf.train.Saver(v_map)
if FLAGS.cond_rot: if FLAGS.cond_rot:
@ -646,53 +822,57 @@ def main():
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
x_mod = X_NOISE
kvs = {} kvs = {}
kvs['X_NOISE'] = X_NOISE kvs["X_NOISE"] = X_NOISE
kvs['LABEL_SIZE'] = LABEL_SIZE kvs["LABEL_SIZE"] = LABEL_SIZE
kvs['LABEL_SHAPE'] = LABEL_SHAPE kvs["LABEL_SHAPE"] = LABEL_SHAPE
kvs['LABEL_POS'] = LABEL_POS kvs["LABEL_POS"] = LABEL_POS
kvs['LABEL_ROT'] = LABEL_ROT kvs["LABEL_ROT"] = LABEL_ROT
kvs['model_size'] = model_size kvs["model_size"] = model_size
kvs['model_shape'] = model_shape kvs["model_shape"] = model_shape
kvs['model_pos'] = model_pos kvs["model_pos"] = model_pos
kvs['model_rot'] = model_rot kvs["model_rot"] = model_rot
kvs['weight_size'] = weight_size kvs["weight_size"] = weight_size
kvs['weight_shape'] = weight_shape kvs["weight_shape"] = weight_shape
kvs['weight_pos'] = weight_pos kvs["weight_pos"] = weight_pos
kvs['weight_rot'] = weight_rot kvs["weight_rot"] = weight_rot
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape)) save_exp_dir = osp.join(
FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
)
if not osp.exists(save_exp_dir): if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir) os.makedirs(save_exp_dir)
if FLAGS.task == "conceptcombine":
if FLAGS.task == 'conceptcombine':
conceptcombine(sess, kvs, data, latents, save_exp_dir) conceptcombine(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'labeldiscover': elif FLAGS.task == "labeldiscover":
labeldiscover(sess, kvs, data, latents, save_exp_dir) labeldiscover(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'gentest': elif FLAGS.task == "gentest":
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos)) save_exp_dir = osp.join(
FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
)
if not osp.exists(save_exp_dir): if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir) os.makedirs(save_exp_dir)
gentest(sess, kvs, data, latents, save_exp_dir) gentest(sess, kvs, data, latents, save_exp_dir)
elif FLAGS.task == 'genbaseline': elif FLAGS.task == "genbaseline":
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos)) save_exp_dir = osp.join(
FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
)
if not osp.exists(save_exp_dir): if not osp.exists(save_exp_dir):
os.makedirs(save_exp_dir) os.makedirs(save_exp_dir)
if FLAGS.plot_curve: if FLAGS.plot_curve:
mse_losses = [] mse_losses = []
for frac in [i/10 for i in range(11)]: for frac in [i / 10 for i in range(11)]:
mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac) mse_loss = genbaseline(
sess, kvs, data, latents, save_exp_dir, frac=frac
)
mse_losses.append(mse_loss) mse_losses.append(mse_loss)
np.save("mse_baseline_comb.npy", mse_losses) np.save("mse_baseline_comb.npy", mse_losses)
else: else:
genbaseline(sess, kvs, data, latents, save_exp_dir) genbaseline(sess, kvs, data, latents, save_exp_dir)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
''' Calculates the Frechet Inception Distance (FID) to evalulate GANs. """Calculates the Frechet Inception Distance (FID) to evalulate GANs.
The FID metric calculates the distance between two distributions of images. The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one Typically, we have summary statistics (mean & covariance matrix) of one
@ -14,28 +14,33 @@ the pool_3 layer of the inception net for generated samples and real world
samples respectivly. samples respectivly.
See --help to see further details. See --help to see further details.
''' """
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import numpy as np
import os import os
import gzip, pickle
import tensorflow as tf
from scipy.misc import imread
from scipy import linalg
import pathlib import pathlib
import urllib
import tarfile import tarfile
import urllib
import warnings import warnings
MODEL_DIR = '/tmp/imagenet' import numpy as np
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' import tensorflow as tf
from scipy import linalg
from scipy.misc import imread
MODEL_DIR = "/tmp/imagenet"
DATA_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
pool3 = None pool3 = None
class InvalidFIDException(Exception): class InvalidFIDException(Exception):
pass pass
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def get_fid_score(images, images_gt): def get_fid_score(images, images_gt):
images = np.stack(images, 0) images = np.stack(images, 0)
images_gt = np.stack(images_gt, 0) images_gt = np.stack(images_gt, 0)
@ -52,34 +57,38 @@ def get_fid_score(images, images_gt):
def create_inception_graph(pth): def create_inception_graph(pth):
"""Creates a graph from saved GraphDef file.""" """Creates a graph from saved GraphDef file."""
# Creates graph from saved graph_def.pb. # Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile( pth, 'rb') as f: with tf.gfile.FastGFile(pth, "rb") as f:
graph_def = tf.GraphDef() graph_def = tf.GraphDef()
graph_def.ParseFromString( f.read()) graph_def.ParseFromString(f.read())
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net') _ = tf.import_graph_def(graph_def, name="FID_Inception_Net")
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# code for handling inception net derived from # code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py # https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess): def _get_inception_layer(sess):
"""Prepares inception net for batched usage and returns pool_3 layer. """ """Prepares inception net for batched usage and returns pool_3 layer."""
layername = 'FID_Inception_Net/pool_3:0' layername = "FID_Inception_Net/pool_3:0"
pool3 = sess.graph.get_tensor_by_name(layername) pool3 = sess.graph.get_tensor_by_name(layername)
ops = pool3.graph.get_operations() ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops): for op_idx, op in enumerate(ops):
for o in op.outputs: for o in op.outputs:
shape = o.get_shape() shape = o.get_shape()
if shape._dims != []: if shape._dims != []:
shape = [s.value for s in shape] shape = [s.value for s in shape]
new_shape = [] new_shape = []
for j, s in enumerate(shape): for j, s in enumerate(shape):
if s == 1 and j == 0: if s == 1 and j == 0:
new_shape.append(None) new_shape.append(None)
else: else:
new_shape.append(s) new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape) o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
return pool3 return pool3
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=50, verbose=False): def get_activations(images, sess, batch_size=50, verbose=False):
@ -100,23 +109,27 @@ def get_activations(images, sess, batch_size=50, verbose=False):
# inception_layer = _get_inception_layer(sess) # inception_layer = _get_inception_layer(sess)
d0 = images.shape[0] d0 = images.shape[0]
if batch_size > d0: if batch_size > d0:
print("warning: batch size is bigger than the data size. setting batch size to data size") print(
"warning: batch size is bigger than the data size. setting batch size to data size"
)
batch_size = d0 batch_size = d0
n_batches = d0//batch_size n_batches = d0 // batch_size
n_used_imgs = n_batches*batch_size n_used_imgs = n_batches * batch_size
pred_arr = np.empty((n_used_imgs,2048)) pred_arr = np.empty((n_used_imgs, 2048))
for i in range(n_batches): for i in range(n_batches):
if verbose: if verbose:
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
start = i*batch_size start = i * batch_size
end = start + batch_size end = start + batch_size
batch = images[start:end] batch = images[start:end]
pred = sess.run(pool3, {'ExpandDims:0': batch}) pred = sess.run(pool3, {"ExpandDims:0": batch})
pred_arr[start:end] = pred.reshape(batch_size,-1) pred_arr[start:end] = pred.reshape(batch_size, -1)
if verbose: if verbose:
print(" done") print(" done")
return pred_arr return pred_arr
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
@ -147,15 +160,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
sigma1 = np.atleast_2d(sigma1) sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2) sigma2 = np.atleast_2d(sigma2)
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" assert (
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
diff = mu1 - mu2 diff = mu1 - mu2
# product might be almost singular # product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all(): if not np.isfinite(covmean).all():
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps msg = (
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
% eps
)
warnings.warn(msg) warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@ -170,7 +190,9 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
tr_covmean = np.trace(covmean) tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False): def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
@ -193,47 +215,52 @@ def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
mu = np.mean(act, axis=0) mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False) sigma = np.cov(act, rowvar=False)
return mu, sigma return mu, sigma
#-------------------------------------------------------------------------------
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# The following functions aren't needed for calculating the FID # The following functions aren't needed for calculating the FID
# they're just here to make this module work as a stand-alone script # they're just here to make this module work as a stand-alone script
# for calculating FID scores # for calculating FID scores
#------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
def check_or_download_inception(inception_path): def check_or_download_inception(inception_path):
''' Checks if the path to the inception file is valid, or downloads """Checks if the path to the inception file is valid, or downloads
the file if it is not present. ''' the file if it is not present."""
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' INCEPTION_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
if inception_path is None: if inception_path is None:
inception_path = '/tmp' inception_path = "/tmp"
inception_path = pathlib.Path(inception_path) inception_path = pathlib.Path(inception_path)
model_file = inception_path / 'classify_image_graph_def.pb' model_file = inception_path / "classify_image_graph_def.pb"
if not model_file.exists(): if not model_file.exists():
print("Downloading Inception model") print("Downloading Inception model")
from urllib import request
import tarfile import tarfile
from urllib import request
fn, _ = request.urlretrieve(INCEPTION_URL) fn, _ = request.urlretrieve(INCEPTION_URL)
with tarfile.open(fn, mode='r') as f: with tarfile.open(fn, mode="r") as f:
f.extract('classify_image_graph_def.pb', str(model_file.parent)) f.extract("classify_image_graph_def.pb", str(model_file.parent))
return str(model_file) return str(model_file)
def _handle_path(path, sess): def _handle_path(path, sess):
if path.endswith('.npz'): if path.endswith(".npz"):
f = np.load(path) f = np.load(path)
m, s = f['mu'][:], f['sigma'][:] m, s = f["mu"][:], f["sigma"][:]
f.close() f.close()
else: else:
path = pathlib.Path(path) path = pathlib.Path(path)
files = list(path.glob('*.jpg')) + list(path.glob('*.png')) files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
m, s = calculate_activation_statistics(x, sess) m, s = calculate_activation_statistics(x, sess)
return m, s return m, s
def calculate_fid_given_paths(paths, inception_path): def calculate_fid_given_paths(paths, inception_path):
''' Calculates the FID of two paths. ''' """Calculates the FID of two paths."""
inception_path = check_or_download_inception(inception_path) inception_path = check_or_download_inception(inception_path)
for p in paths: for p in paths:
@ -250,43 +277,48 @@ def calculate_fid_given_paths(paths, inception_path):
def _init_inception(): def _init_inception():
global pool3 global pool3
if not os.path.exists(MODEL_DIR): if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR) os.makedirs(MODEL_DIR)
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split("/")[-1]
filepath = os.path.join(MODEL_DIR, filename) filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath): if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % ( def _progress(count, block_size, total_size):
filename, float(count * block_size) / float(total_size) * 100.0)) sys.stdout.write(
sys.stdout.flush() "\r>> Downloading %s %.1f%%"
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) % (filename, float(count * block_size) / float(total_size) * 100.0)
print() )
statinfo = os.stat(filepath) sys.stdout.flush()
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
with tf.gfile.FastGFile(os.path.join( print()
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: statinfo = os.stat(filepath)
graph_def = tf.GraphDef() print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
graph_def.ParseFromString(f.read()) tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
_ = tf.import_graph_def(graph_def, name='') with tf.gfile.FastGFile(
# Works with an arbitrary minibatch size. os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
with tf.Session() as sess: ) as f:
pool3 = sess.graph.get_tensor_by_name('pool_3:0') graph_def = tf.GraphDef()
ops = pool3.graph.get_operations() graph_def.ParseFromString(f.read())
for op_idx, op in enumerate(ops): _ = tf.import_graph_def(graph_def, name="")
for o in op.outputs: # Works with an arbitrary minibatch size.
shape = o.get_shape() with tf.Session() as sess:
if shape._dims != []: pool3 = sess.graph.get_tensor_by_name("pool_3:0")
shape = [s.value for s in shape] ops = pool3.graph.get_operations()
new_shape = [] for op_idx, op in enumerate(ops):
for j, s in enumerate(shape): for o in op.outputs:
if s == 1 and j == 0: shape = o.get_shape()
new_shape.append(None) if shape._dims != []:
else: shape = [s.value for s in shape]
new_shape.append(s) new_shape = []
o.__dict__['_shape_val'] = tf.TensorShape(new_shape) for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
if pool3 is None: if pool3 is None:
_init_inception() _init_inception()

View file

@ -1,11 +1,11 @@
import tensorflow as tf import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags from tensorflow.python.platform import flags
flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance raes')
flags.DEFINE_bool("proposal_debug", False, "Print hmc acceptance raes")
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def kinetic_energy(velocity): def kinetic_energy(velocity):
"""Kinetic energy of the current velocity (assuming a standard Gaussian) """Kinetic energy of the current velocity (assuming a standard Gaussian)
(x dot x) / 2 (x dot x) / 2
@ -21,6 +21,7 @@ def kinetic_energy(velocity):
""" """
return 0.5 * tf.square(velocity) return 0.5 * tf.square(velocity)
def hamiltonian(position, velocity, energy_function): def hamiltonian(position, velocity, energy_function):
"""Computes the Hamiltonian of the current position, velocity pair """Computes the Hamiltonian of the current position, velocity pair
@ -44,13 +45,12 @@ def hamiltonian(position, velocity, energy_function):
""" """
batch_size = tf.shape(velocity)[0] batch_size = tf.shape(velocity)[0]
kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1)) kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1]) return tf.squeeze(energy_function(position)) + tf.reduce_sum(
kinetic_energy_flat, axis=[1]
)
def leapfrog_step(x0,
v0, def leapfrog_step(x0, v0, neg_log_posterior, step_size, num_steps):
neg_log_posterior,
step_size,
num_steps):
# Start by updating the velocity a half-step # Start by updating the velocity a half-step
v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0] v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
@ -83,10 +83,8 @@ def leapfrog_step(x0,
# return new proposal state # return new proposal state
return x, v return x, v
def hmc(initial_x,
step_size, def hmc(initial_x, step_size, num_steps, neg_log_posterior):
num_steps,
neg_log_posterior):
"""Summary """Summary
Parameters Parameters
@ -107,11 +105,13 @@ def hmc(initial_x,
""" """
v0 = tf.random_normal(tf.shape(initial_x)) v0 = tf.random_normal(tf.shape(initial_x))
x, v = leapfrog_step(initial_x, x, v = leapfrog_step(
v0, initial_x,
step_size=step_size, v0,
num_steps=num_steps, step_size=step_size,
neg_log_posterior=neg_log_posterior) num_steps=num_steps,
neg_log_posterior=neg_log_posterior,
)
orig = hamiltonian(initial_x, v0, neg_log_posterior) orig = hamiltonian(initial_x, v0, neg_log_posterior)
current = hamiltonian(x, v, neg_log_posterior) current = hamiltonian(x, v, neg_log_posterior)
@ -119,10 +119,12 @@ def hmc(initial_x,
prob_accept = tf.exp(orig - current) prob_accept = tf.exp(orig - current)
if FLAGS.proposal_debug: if FLAGS.proposal_debug:
prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]) prob_accept = tf.Print(
prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]
)
uniform = tf.random_uniform(tf.shape(prob_accept)) uniform = tf.random_uniform(tf.shape(prob_accept))
keep_mask = (prob_accept > uniform) keep_mask = prob_accept > uniform
# print(keep_mask.get_shape()) # print(keep_mask.get_shape())
x_new = tf.where(keep_mask, x, initial_x) x_new = tf.where(keep_mask, x, initial_x)

View file

@ -1,24 +1,28 @@
from models import ResNet128
import numpy as np
import os.path as osp import os.path as osp
from tensorflow.python.platform import flags
import tensorflow as tf
import imageio import imageio
from utils import optimistic_restore import numpy as np
import tensorflow as tf
from models import ResNet128
from tensorflow.python.platform import flags
flags.DEFINE_string(
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') "logdir", "cachedir", "location where log of experiments will be stored"
flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling') )
flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics') flags.DEFINE_integer("num_steps", 200, "num of steps for conditional imagenet sampling")
flags.DEFINE_integer('batch_size', 16, 'number of steps to run') flags.DEFINE_float("step_lr", 180.0, "step size for Langevin dynamics")
flags.DEFINE_string('exp', 'default', 'name of experiments') flags.DEFINE_integer("batch_size", 16, "number of steps to run")
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from') flags.DEFINE_string("exp", "default", "name of experiments")
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model') flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from")
flags.DEFINE_bool('cclass', True, 'conditional models') flags.DEFINE_bool(
flags.DEFINE_bool('use_attention', False, 'using attention') "spec_norm", True, "whether to use spectral normalization in weights in a model"
)
flags.DEFINE_bool("cclass", True, "conditional models")
flags.DEFINE_bool("use_attention", False, "using attention")
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def rescale_im(im): def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8) return np.clip(im * 256, 0, 255).astype(np.uint8)
@ -32,12 +36,11 @@ if __name__ == "__main__":
weights = model.construct_weights("context_0") weights = model.construct_weights("context_0")
x_mod = X_NOISE x_mod = X_NOISE
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
mean=0.0,
stddev=0.005)
energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL, energy_noise = energy_start = model.forward(
reuse=True, stop_at_grad=False, stop_batch=True) x_mod, weights, label=LABEL, reuse=True, stop_at_grad=False, stop_batch=True
)
x_grad = tf.gradients(energy_noise, [x_mod])[0] x_grad = tf.gradients(energy_noise, [x_mod])[0]
energy_noise_old = energy_noise energy_noise_old = energy_noise
@ -54,7 +57,7 @@ if __name__ == "__main__":
saver = loader = tf.train.Saver() saver = loader = tf.train.Saver()
logdir = osp.join(FLAGS.logdir, FLAGS.exp) logdir = osp.join(FLAGS.logdir, FLAGS.exp)
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
saver.restore(sess, model_file) saver.restore(sess, model_file)
lx = np.random.permutation(1000)[:16] lx = np.random.permutation(1000)[:16]
@ -65,9 +68,12 @@ if __name__ == "__main__":
labels = np.eye(1000)[lx] labels = np.eye(1000)[lx]
for i in range(FLAGS.num_steps): for i in range(FLAGS.num_steps):
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels}) e, x_mod = sess.run([energy_noise, x_output], {X_NOISE: x_mod, LABEL: labels})
ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3))) ims.append(
rescale_im(x_mod)
imageio.mimwrite('sample.gif', ims) .reshape((4, 4, 128, 128, 3))
.transpose((0, 2, 1, 3, 4))
.reshape((512, 512, 3))
)
imageio.mimwrite("sample.gif", ims)

View file

@ -13,14 +13,11 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Image pre-processing utilities. """Image pre-processing utilities."""
"""
import tensorflow as tf import tensorflow as tf
IMAGE_DEPTH = 3 # color images
IMAGE_DEPTH = 3 # color images
import tensorflow as tf
# _R_MEAN = 123.68 # _R_MEAN = 123.68
# _G_MEAN = 116.78 # _G_MEAN = 116.78
@ -35,303 +32,318 @@ _RESIZE_MIN = 128
def _decode_crop_and_flip(image_buffer, bbox, num_channels): def _decode_crop_and_flip(image_buffer, bbox, num_channels):
"""Crops the given image to a random part of the image, and randomly flips. """Crops the given image to a random part of the image, and randomly flips.
We use the fused decode_and_crop op, which performs better than the two ops We use the fused decode_and_crop op, which performs better than the two ops
used separately in series, but note that this requires that the image be used separately in series, but note that this requires that the image be
passed in as an un-decoded string Tensor. passed in as an un-decoded string Tensor.
Args: Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer. image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax]. [ymin, xmin, ymax, xmax].
num_channels: Integer depth of the image buffer for decoding. num_channels: Integer depth of the image buffer for decoding.
Returns: Returns:
3-D tensor with cropped image. 3-D tensor with cropped image.
""" """
# A large fraction of image datasets contain a human-annotated bounding box # A large fraction of image datasets contain a human-annotated bounding box
# delineating the region of the image containing the object of interest. We # delineating the region of the image containing the object of interest. We
# choose to create a new bounding box for the object which is a randomly # choose to create a new bounding box for the object which is a randomly
# distorted version of the human-annotated bounding box that obeys an # distorted version of the human-annotated bounding box that obeys an
# allowed range of aspect ratios, sizes and overlap with the human-annotated # allowed range of aspect ratios, sizes and overlap with the human-annotated
# bounding box. If no box is supplied, then we assume the bounding box is # bounding box. If no box is supplied, then we assume the bounding box is
# the entire image. # the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.image.extract_jpeg_shape(image_buffer), tf.image.extract_jpeg_shape(image_buffer),
bounding_boxes=bbox, bounding_boxes=bbox,
min_object_covered=0.1, min_object_covered=0.1,
aspect_ratio_range=[0.75, 1.33], aspect_ratio_range=[0.75, 1.33],
area_range=[0.05, 1.0], area_range=[0.05, 1.0],
max_attempts=100, max_attempts=100,
use_image_if_no_bounding_boxes=True) use_image_if_no_bounding_boxes=True,
bbox_begin, bbox_size, _ = sample_distorted_bounding_box )
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
# Reassemble the bounding box in the format the crop op requires. # Reassemble the bounding box in the format the crop op requires.
offset_y, offset_x, _ = tf.unstack(bbox_begin) offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size) target_height, target_width, _ = tf.unstack(bbox_size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
# Use the fused decode and crop op here, which is faster than each in series. # Use the fused decode and crop op here, which is faster than each in
cropped = tf.image.decode_and_crop_jpeg( # series.
image_buffer, crop_window, channels=num_channels) cropped = tf.image.decode_and_crop_jpeg(
image_buffer, crop_window, channels=num_channels
)
# Flip to add a little more random distortion in. # Flip to add a little more random distortion in.
cropped = tf.image.random_flip_left_right(cropped) cropped = tf.image.random_flip_left_right(cropped)
return cropped return cropped
def _central_crop(image, crop_height, crop_width): def _central_crop(image, crop_height, crop_width):
"""Performs central crops of the given image list. """Performs central crops of the given image list.
Args: Args:
image: a 3-D image tensor image: a 3-D image tensor
crop_height: the height of the image following the crop. crop_height: the height of the image following the crop.
crop_width: the width of the image following the crop. crop_width: the width of the image following the crop.
Returns: Returns:
3-D tensor with cropped image. 3-D tensor with cropped image.
""" """
shape = tf.shape(input=image) shape = tf.shape(input=image)
height, width = shape[0], shape[1] height, width = shape[0], shape[1]
amount_to_be_cropped_h = (height - crop_height) amount_to_be_cropped_h = height - crop_height
crop_top = amount_to_be_cropped_h // 2 crop_top = amount_to_be_cropped_h // 2
amount_to_be_cropped_w = (width - crop_width) amount_to_be_cropped_w = width - crop_width
crop_left = amount_to_be_cropped_w // 2 crop_left = amount_to_be_cropped_w // 2
return tf.slice( return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
def _mean_image_subtraction(image, means, num_channels): def _mean_image_subtraction(image, means, num_channels):
"""Subtracts the given means from each image channel. """Subtracts the given means from each image channel.
For example: For example:
means = [123.68, 116.779, 103.939] means = [123.68, 116.779, 103.939]
image = _mean_image_subtraction(image, means) image = _mean_image_subtraction(image, means)
Note that the rank of `image` must be known. Note that the rank of `image` must be known.
Args: Args:
image: a tensor of size [height, width, C]. image: a tensor of size [height, width, C].
means: a C-vector of values to subtract from each channel. means: a C-vector of values to subtract from each channel.
num_channels: number of color channels in the image that will be distorted. num_channels: number of color channels in the image that will be distorted.
Returns: Returns:
the centered image. the centered image.
Raises: Raises:
ValueError: If the rank of `image` is unknown, if `image` has a rank other ValueError: If the rank of `image` is unknown, if `image` has a rank other
than three or if the number of channels in `image` doesn't match the than three or if the number of channels in `image` doesn't match the
number of values in `means`. number of values in `means`.
""" """
if image.get_shape().ndims != 3: if image.get_shape().ndims != 3:
raise ValueError('Input must be of size [height, width, C>0]') raise ValueError("Input must be of size [height, width, C>0]")
if len(means) != num_channels: if len(means) != num_channels:
raise ValueError('len(means) must match the number of channels') raise ValueError("len(means) must match the number of channels")
# We have a 1-D tensor of means; convert to 3-D. # We have a 1-D tensor of means; convert to 3-D.
means = tf.expand_dims(tf.expand_dims(means, 0), 0) means = tf.expand_dims(tf.expand_dims(means, 0), 0)
return image - means return image - means
def _smallest_size_at_least(height, width, resize_min): def _smallest_size_at_least(height, width, resize_min):
"""Computes new shape with the smallest side equal to `smallest_side`. """Computes new shape with the smallest side equal to `smallest_side`.
Computes new shape with the smallest side equal to `smallest_side` while Computes new shape with the smallest side equal to `smallest_side` while
preserving the original aspect ratio. preserving the original aspect ratio.
Args: Args:
height: an int32 scalar tensor indicating the current height. height: an int32 scalar tensor indicating the current height.
width: an int32 scalar tensor indicating the current width. width: an int32 scalar tensor indicating the current width.
resize_min: A python integer or scalar `Tensor` indicating the size of resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize. the smallest side after resize.
Returns: Returns:
new_height: an int32 scalar tensor indicating the new height. new_height: an int32 scalar tensor indicating the new height.
new_width: an int32 scalar tensor indicating the new width. new_width: an int32 scalar tensor indicating the new width.
""" """
resize_min = tf.cast(resize_min, tf.float32) resize_min = tf.cast(resize_min, tf.float32)
# Convert to floats to make subsequent calculations go smoothly. # Convert to floats to make subsequent calculations go smoothly.
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
smaller_dim = tf.minimum(height, width) smaller_dim = tf.minimum(height, width)
scale_ratio = resize_min / smaller_dim scale_ratio = resize_min / smaller_dim
# Convert back to ints to make heights and widths that TF ops will accept. # Convert back to ints to make heights and widths that TF ops will accept.
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32) new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32) new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
return new_height, new_width return new_height, new_width
def _aspect_preserving_resize(image, resize_min): def _aspect_preserving_resize(image, resize_min):
"""Resize images preserving the original aspect ratio. """Resize images preserving the original aspect ratio.
Args: Args:
image: A 3-D image `Tensor`. image: A 3-D image `Tensor`.
resize_min: A python integer or scalar `Tensor` indicating the size of resize_min: A python integer or scalar `Tensor` indicating the size of
the smallest side after resize. the smallest side after resize.
Returns: Returns:
resized_image: A 3-D tensor containing the resized image. resized_image: A 3-D tensor containing the resized image.
""" """
shape = tf.shape(input=image) shape = tf.shape(input=image)
height, width = shape[0], shape[1] height, width = shape[0], shape[1]
new_height, new_width = _smallest_size_at_least(height, width, resize_min) new_height, new_width = _smallest_size_at_least(height, width, resize_min)
return _resize_image(image, new_height, new_width) return _resize_image(image, new_height, new_width)
def _resize_image(image, height, width): def _resize_image(image, height, width):
"""Simple wrapper around tf.resize_images. """Simple wrapper around tf.resize_images.
This is primarily to make sure we use the same `ResizeMethod` and other This is primarily to make sure we use the same `ResizeMethod` and other
details each time. details each time.
Args: Args:
image: A 3-D image `Tensor`. image: A 3-D image `Tensor`.
height: The target height for the resized image. height: The target height for the resized image.
width: The target width for the resized image. width: The target width for the resized image.
Returns: Returns:
resized_image: A 3-D tensor containing the resized image. The first two resized_image: A 3-D tensor containing the resized image. The first two
dimensions have the shape [height, width]. dimensions have the shape [height, width].
""" """
return tf.image.resize_images( return tf.image.resize_images(
image, [height, width], method=tf.image.ResizeMethod.BILINEAR, image,
align_corners=False) [height, width],
method=tf.image.ResizeMethod.BILINEAR,
align_corners=False,
)
def preprocess_image(image_buffer, bbox, output_height, output_width, def preprocess_image(
num_channels, is_training=False): image_buffer, bbox, output_height, output_width, num_channels, is_training=False
"""Preprocesses the given image. ):
"""Preprocesses the given image.
Preprocessing includes decoding, cropping, and resizing for both training Preprocessing includes decoding, cropping, and resizing for both training
and eval images. Training preprocessing, however, introduces some random and eval images. Training preprocessing, however, introduces some random
distortion of the image to improve accuracy. distortion of the image to improve accuracy.
Args: Args:
image_buffer: scalar string Tensor representing the raw JPEG image buffer. image_buffer: scalar string Tensor representing the raw JPEG image buffer.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax]. [ymin, xmin, ymax, xmax].
output_height: The height of the image after preprocessing. output_height: The height of the image after preprocessing.
output_width: The width of the image after preprocessing. output_width: The width of the image after preprocessing.
num_channels: Integer depth of the image buffer for decoding. num_channels: Integer depth of the image buffer for decoding.
is_training: `True` if we're preprocessing the image for training and is_training: `True` if we're preprocessing the image for training and
`False` otherwise. `False` otherwise.
Returns: Returns:
A preprocessed image. A preprocessed image.
""" """
if is_training: if is_training:
# For training, we want to randomize some of the distortions. # For training, we want to randomize some of the distortions.
image = _decode_crop_and_flip(image_buffer, bbox, num_channels) image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
image = _resize_image(image, output_height, output_width) image = _resize_image(image, output_height, output_width)
else: else:
# For validation, we want to decode, resize, then just crop the middle. # For validation, we want to decode, resize, then just crop the middle.
image = tf.image.decode_jpeg(image_buffer, channels=num_channels) image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
image = _aspect_preserving_resize(image, _RESIZE_MIN) image = _aspect_preserving_resize(image, _RESIZE_MIN)
print(image) print(image)
image = _central_crop(image, output_height, output_width) image = _central_crop(image, output_height, output_width)
image.set_shape([output_height, output_width, num_channels]) image.set_shape([output_height, output_width, num_channels])
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels) return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
def parse_example_proto(example_serialized): def parse_example_proto(example_serialized):
"""Parses an Example proto containing a training example of an image. """Parses an Example proto containing a training example of an image.
The output of the build_image_data.py image preprocessing script is a dataset The output of the build_image_data.py image preprocessing script is a dataset
containing serialized Example protocol buffers. Each Example proto contains containing serialized Example protocol buffers. Each Example proto contains
the following fields: the following fields:
image/height: 462 image/height: 462
image/width: 581 image/width: 581
image/colorspace: 'RGB' image/colorspace: 'RGB'
image/channels: 3 image/channels: 3
image/class/label: 615 image/class/label: 615
image/class/synset: 'n03623198' image/class/synset: 'n03623198'
image/class/text: 'knee pad' image/class/text: 'knee pad'
image/object/bbox/xmin: 0.1 image/object/bbox/xmin: 0.1
image/object/bbox/xmax: 0.9 image/object/bbox/xmax: 0.9
image/object/bbox/ymin: 0.2 image/object/bbox/ymin: 0.2
image/object/bbox/ymax: 0.6 image/object/bbox/ymax: 0.6
image/object/bbox/label: 615 image/object/bbox/label: 615
image/format: 'JPEG' image/format: 'JPEG'
image/filename: 'ILSVRC2012_val_00041207.JPEG' image/filename: 'ILSVRC2012_val_00041207.JPEG'
image/encoded: <JPEG encoded string> image/encoded: <JPEG encoded string>
Args: Args:
example_serialized: scalar Tensor tf.string containing a serialized example_serialized: scalar Tensor tf.string containing a serialized
Example protocol buffer. Example protocol buffer.
Returns: Returns:
image_buffer: Tensor tf.string containing the contents of a JPEG file. image_buffer: Tensor tf.string containing the contents of a JPEG file.
label: Tensor tf.int32 containing the label. label: Tensor tf.int32 containing the label.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax]. [ymin, xmin, ymax, xmax].
text: Tensor tf.string containing the human-readable label. text: Tensor tf.string containing the human-readable label.
""" """
# Dense features in Example proto. # Dense features in Example proto.
feature_map = { feature_map = {
'image/encoded': tf.FixedLenFeature([], dtype=tf.string, "image/encoded": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
default_value=''), "image/class/label": tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, "image/class/text": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
default_value=-1), }
'image/class/text': tf.FixedLenFeature([], dtype=tf.string, sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
default_value=''), # Sparse features in Example proto.
} feature_map.update(
sparse_float32 = tf.VarLenFeature(dtype=tf.float32) {
# Sparse features in Example proto. k: sparse_float32
feature_map.update( for k in [
{k: sparse_float32 for k in ['image/object/bbox/xmin', "image/object/bbox/xmin",
'image/object/bbox/ymin', "image/object/bbox/ymin",
'image/object/bbox/xmax', "image/object/bbox/xmax",
'image/object/bbox/ymax']}) "image/object/bbox/ymax",
]
}
)
features = tf.parse_single_example(example_serialized, feature_map) features = tf.parse_single_example(example_serialized, feature_map)
label = tf.cast(features['image/class/label'], dtype=tf.int32) label = tf.cast(features["image/class/label"], dtype=tf.int32)
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0) xmin = tf.expand_dims(features["image/object/bbox/xmin"].values, 0)
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0) ymin = tf.expand_dims(features["image/object/bbox/ymin"].values, 0)
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0) xmax = tf.expand_dims(features["image/object/bbox/xmax"].values, 0)
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0) ymax = tf.expand_dims(features["image/object/bbox/ymax"].values, 0)
# Note that we impose an ordering of (y, x) just to make life difficult. # Note that we impose an ordering of (y, x) just to make life difficult.
bbox = tf.concat([ymin, xmin, ymax, xmax], 0) bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
# Force the variable number of bounding boxes into the shape # Force the variable number of bounding boxes into the shape
# [1, num_boxes, coords]. # [1, num_boxes, coords].
bbox = tf.expand_dims(bbox, 0) bbox = tf.expand_dims(bbox, 0)
bbox = tf.transpose(bbox, [0, 2, 1]) bbox = tf.transpose(bbox, [0, 2, 1])
return features['image/encoded'], label, bbox, features['image/class/text'] return features["image/encoded"], label, bbox, features["image/class/text"]
class ImagenetPreprocessor: class ImagenetPreprocessor:
def __init__(self, image_size, dtype, train): def __init__(self, image_size, dtype, train):
self.image_size = image_size self.image_size = image_size
self.dtype = dtype self.dtype = dtype
self.train = train self.train = train
def preprocess(self, image_buffer, bbox): def preprocess(self, image_buffer, bbox):
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train) image = preprocess_image(
return tf.cast(image, self.dtype) image_buffer,
bbox,
def parse_and_preprocess(self, value): self.image_size,
image_buffer, label_index, bbox, _ = parse_example_proto(value) self.image_size,
image = self.preprocess(image_buffer, bbox) IMAGE_DEPTH,
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH]) is_training=self.train,
return label_index, image )
return tf.cast(image, self.dtype)
def parse_and_preprocess(self, value):
image_buffer, label_index, bbox, _ = parse_example_proto(value)
image = self.preprocess(image_buffer, bbox)
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
return label_index, image

View file

@ -1,105 +1,112 @@
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py # Code derived from
from __future__ import absolute_import # tensorflow/tensorflow/models/image/imagenet/classify_image.py
from __future__ import division from __future__ import absolute_import, division, print_function
from __future__ import print_function
import math
import os.path import os.path
import sys import sys
import tarfile import tarfile
import numpy as np
from six.moves import urllib
import tensorflow as tf
import glob
import scipy.misc
import math
import sys
import horovod.tensorflow as hvd import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf
from six.moves import urllib
MODEL_DIR = '/tmp/imagenet' MODEL_DIR = "/tmp/imagenet"
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' DATA_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
softmax = None softmax = None
config = tf.ConfigProto() config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank()) config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config) sess = tf.Session(config=config)
# Call this function with list of images. Each of elements should be a # Call this function with list of images. Each of elements should be a
# numpy array with values ranging from 0 to 255. # numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10): def get_inception_score(images, splits=10):
# For convenience # For convenience
if len(images[0].shape) != 3: if len(images[0].shape) != 3:
return 0, 0 return 0, 0
# Bypassing all the assertions so that we don't end prematuraly'
# assert(type(images) == list)
# assert(type(images[0]) == np.ndarray)
# assert(len(images[0].shape) == 3)
# assert(np.max(images[0]) > 10)
# assert(np.min(images[0]) >= 0.0)
inps = []
for img in images:
img = img.astype(np.float32)
inps.append(np.expand_dims(img, 0))
bs = 1
preds = []
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
for i in range(n_batches):
sys.stdout.write(".")
sys.stdout.flush()
inp = inps[(i * bs) : min((i + 1) * bs, len(inps))]
inp = np.concatenate(inp, 0)
pred = sess.run(softmax, {"ExpandDims:0": inp})
preds.append(pred)
preds = np.concatenate(preds, 0)
scores = []
for i in range(splits):
part = preds[
(i * preds.shape[0] // splits) : ((i + 1) * preds.shape[0] // splits), :
]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return np.mean(scores), np.std(scores)
# Bypassing all the assertions so that we don't end prematuraly'
# assert(type(images) == list)
# assert(type(images[0]) == np.ndarray)
# assert(len(images[0].shape) == 3)
# assert(np.max(images[0]) > 10)
# assert(np.min(images[0]) >= 0.0)
inps = []
for img in images:
img = img.astype(np.float32)
inps.append(np.expand_dims(img, 0))
bs = 1
preds = []
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
for i in range(n_batches):
sys.stdout.write(".")
sys.stdout.flush()
inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
inp = np.concatenate(inp, 0)
pred = sess.run(softmax, {'ExpandDims:0': inp})
preds.append(pred)
preds = np.concatenate(preds, 0)
scores = []
for i in range(splits):
part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return np.mean(scores), np.std(scores)
# This function is called automatically. # This function is called automatically.
def _init_inception(): def _init_inception():
global softmax global softmax
if not os.path.exists(MODEL_DIR): if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR) os.makedirs(MODEL_DIR)
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split("/")[-1]
filepath = os.path.join(MODEL_DIR, filename) filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath): if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % ( def _progress(count, block_size, total_size):
filename, float(count * block_size) / float(total_size) * 100.0)) sys.stdout.write(
sys.stdout.flush() "\r>> Downloading %s %.1f%%"
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) % (filename, float(count * block_size) / float(total_size) * 100.0)
print() )
statinfo = os.stat(filepath) sys.stdout.flush()
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
with tf.gfile.FastGFile(os.path.join( print()
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: statinfo = os.stat(filepath)
graph_def = tf.GraphDef() print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
graph_def.ParseFromString(f.read()) tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
_ = tf.import_graph_def(graph_def, name='') with tf.gfile.FastGFile(
# Works with an arbitrary minibatch size. os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
pool3 = sess.graph.get_tensor_by_name('pool_3:0') ) as f:
ops = pool3.graph.get_operations() graph_def = tf.GraphDef()
for op_idx, op in enumerate(ops): graph_def.ParseFromString(f.read())
for o in op.outputs: _ = tf.import_graph_def(graph_def, name="")
shape = o.get_shape() # Works with an arbitrary minibatch size.
shape = [s.value for s in shape] pool3 = sess.graph.get_tensor_by_name("pool_3:0")
new_shape = [] ops = pool3.graph.get_operations()
for j, s in enumerate(shape): for op_idx, op in enumerate(ops):
if s == 1 and j == 0: for o in op.outputs:
new_shape.append(None) shape = o.get_shape()
else: shape = [s.value for s in shape]
new_shape.append(s) new_shape = []
o.set_shape(tf.TensorShape(new_shape)) for j, s in enumerate(shape):
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] if s == 1 and j == 0:
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w) new_shape.append(None)
softmax = tf.nn.softmax(logits) else:
new_shape.append(s)
o.set_shape(tf.TensorShape(new_shape))
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
softmax = tf.nn.softmax(logits)
if softmax is None: if softmax is None:
_init_inception() _init_inception()

File diff suppressed because it is too large Load diff

View file

@ -1,53 +1,56 @@
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import flags
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
import os.path as osp import os.path as osp
import os
from utils import optimistic_restore, remap_restore, optimistic_remap_restore
from tqdm import tqdm
import random import random
from scipy.misc import imsave
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
from torch.utils.data import DataLoader
from baselines.common.tf_util import initialize
import horovod.tensorflow as hvd import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf
from baselines.common.tf_util import initialize
from data import Cifar10, Imagenet, TFImagenetLoader
from fid import get_fid_score
from inception import get_inception_score
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
from scipy.misc import imsave
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import optimistic_remap_restore
hvd.init() hvd.init()
from inception import get_inception_score
from fid import get_fid_score
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') flags.DEFINE_string(
flags.DEFINE_string('exp', 'default', 'name of experiments') "logdir", "cachedir", "location where log of experiments will be stored"
flags.DEFINE_bool('cclass', False, 'whether to condition on class') )
flags.DEFINE_string("exp", "default", "name of experiments")
flags.DEFINE_bool("cclass", False, "whether to condition on class")
# Architecture settings # Architecture settings
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not') flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent') flags.DEFINE_float("step_lr", 10.0, "Size of steps for gradient descent")
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label') flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images') flags.DEFINE_float("proj_norm", 0.05, "Maximum change of input images")
flags.DEFINE_integer('batch_size', 512, 'batch size') flags.DEFINE_integer("batch_size", 512, "batch size")
flags.DEFINE_integer('resume_iter', -1, 'resume iteration') flags.DEFINE_integer("resume_iter", -1, "resume iteration")
flags.DEFINE_integer('ensemble', 10, 'number of ensembles') flags.DEFINE_integer("ensemble", 10, "number of ensembles")
flags.DEFINE_integer('im_number', 50000, 'number of ensembles') flags.DEFINE_integer("im_number", 50000, "number of ensembles")
flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations') flags.DEFINE_integer("repeat_scale", 100, "number of repeat iterations")
flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output') flags.DEFINE_float("noise_scale", 0.005, "amount of noise to output")
flags.DEFINE_integer('idx', 0, 'save index') flags.DEFINE_integer("idx", 0, "save index")
flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing') flags.DEFINE_integer("nomix", 10, "number of intervals to stop mixing")
flags.DEFINE_bool('scaled', True, 'whether to scale noise added') flags.DEFINE_bool("scaled", True, "whether to scale noise added")
flags.DEFINE_bool('large_model', False, 'whether to use a small or large model') flags.DEFINE_bool("large_model", False, "whether to use a small or large model")
flags.DEFINE_bool('larger_model', False, 'Whether to use a large model') flags.DEFINE_bool("larger_model", False, "Whether to use a large model")
flags.DEFINE_bool('wider_model', False, 'Whether to use a large model') flags.DEFINE_bool("wider_model", False, "Whether to use a large model")
flags.DEFINE_bool('single', False, 'single ') flags.DEFINE_bool("single", False, "single ")
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single') flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull') flags.DEFINE_string("dataset", "cifar10", "cifar10 or imagenet or imagenetfull")
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class InceptionReplayBuffer(object): class InceptionReplayBuffer(object):
def __init__(self, size): def __init__(self, size):
"""Create Replay buffer. """Create Replay buffer.
@ -72,14 +75,16 @@ class InceptionReplayBuffer(object):
self._label_storage.extend(list(labels)) self._label_storage.extend(list(labels))
else: else:
if batch_size + self._next_idx < self._maxsize: if batch_size + self._next_idx < self._maxsize:
self._storage[self._next_idx:self._next_idx+batch_size] = list(ims) self._storage[self._next_idx : self._next_idx + batch_size] = list(ims)
self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels) self._label_storage[self._next_idx : self._next_idx + batch_size] = (
list(labels)
)
else: else:
split_idx = self._maxsize - self._next_idx split_idx = self._maxsize - self._next_idx
self._storage[self._next_idx:] = list(ims)[:split_idx] self._storage[self._next_idx :] = list(ims)[:split_idx]
self._storage[:batch_size-split_idx] = list(ims)[split_idx:] self._storage[: batch_size - split_idx] = list(ims)[split_idx:]
self._label_storage[self._next_idx:] = list(labels)[:split_idx] self._label_storage[self._next_idx :] = list(labels)[:split_idx]
self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:] self._label_storage[: batch_size - split_idx] = list(labels)[split_idx:]
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
@ -123,12 +128,13 @@ class InceptionReplayBuffer(object):
def rescale_im(im): def rescale_im(im):
return np.clip(im * 256, 0, 255).astype(np.uint8) return np.clip(im * 256, 0, 255).astype(np.uint8)
def compute_inception(sess, target_vars): def compute_inception(sess, target_vars):
X_START = target_vars['X_START'] X_START = target_vars["X_START"]
Y_GT = target_vars['Y_GT'] Y_GT = target_vars["Y_GT"]
X_finals = target_vars['X_finals'] X_finals = target_vars["X_finals"]
NOISE_SCALE = target_vars['NOISE_SCALE'] NOISE_SCALE = target_vars["NOISE_SCALE"]
energy_noise = target_vars['energy_noise'] energy_noise = target_vars["energy_noise"]
size = FLAGS.im_number size = FLAGS.im_number
num_steps = size // 1000 num_steps = size // 1000
@ -136,16 +142,21 @@ def compute_inception(sess, target_vars):
images = [] images = []
test_ims = [] test_ims = []
if FLAGS.dataset == "cifar10": if FLAGS.dataset == "cifar10":
test_dataset = Cifar10(full=True, noise=False) test_dataset = Cifar10(full=True, noise=False)
elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull": elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
test_dataset = Imagenet(train=False) test_dataset = Imagenet(train=False)
if FLAGS.dataset != "imagenetfull": if FLAGS.dataset != "imagenetfull":
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False) test_dataloader = DataLoader(
test_dataset,
batch_size=FLAGS.batch_size,
num_workers=4,
shuffle=True,
drop_last=False,
)
else: else:
test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1) test_dataloader = TFImagenetLoader("test", FLAGS.batch_size, 0, 1)
for data_corrupt, data, label_gt in tqdm(test_dataloader): for data_corrupt, data, label_gt in tqdm(test_dataloader):
data = data.numpy() data = data.numpy()
@ -155,7 +166,6 @@ def compute_inception(sess, target_vars):
test_ims = test_ims[:60000] test_ims = test_ims[:60000]
break break
# n = min(len(images), len(test_ims)) # n = min(len(images), len(test_ims))
print(len(test_ims)) print(len(test_ims))
# fid = get_fid_score(test_ims[:30000], test_ims[-30000:]) # fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
@ -187,12 +197,14 @@ def compute_inception(sess, target_vars):
x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3)) x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
label = np.random.randint(0, classes, (FLAGS.batch_size)) label = np.random.randint(0, classes, (FLAGS.batch_size))
label = identity[label] label = identity[label]
x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0] x_new = sess.run(
[x_final], {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale}
)[0]
data_buffer.add(x_new, label) data_buffer.add(x_new, label)
else: else:
(x_init, label), idx = data_buffer.sample(FLAGS.batch_size) (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99) keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99
label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9) label_keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9
label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size)) label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
label_corrupt = identity[label_corrupt] label_corrupt = identity[label_corrupt]
x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3)) x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
@ -203,7 +215,10 @@ def compute_inception(sess, target_vars):
# else: # else:
# noise_scale = [0.7] # noise_scale = [0.7]
x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale}) x_new, e_noise = sess.run(
[x_final, energy_noise],
{X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale},
)
data_buffer.set_elms(idx, x_new, label) data_buffer.set_elms(idx, x_new, label)
if FLAGS.im_number != 50000: if FLAGS.im_number != 50000:
@ -216,14 +231,22 @@ def compute_inception(sess, target_vars):
images.extend(list(ims)) images.extend(list(ims))
saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx)) saveim = osp.join("sandbox_cachedir", FLAGS.exp, "test{}.png".format(FLAGS.idx))
ims = ims[:100] ims = ims[:100]
if FLAGS.dataset != "imagenetfull": if FLAGS.dataset != "imagenetfull":
im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3)) im_panel = (
ims.reshape((10, 10, 32, 32, 3))
.transpose((0, 2, 1, 3, 4))
.reshape((320, 320, 3))
)
else: else:
im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3)) im_panel = (
ims.reshape((10, 10, 128, 128, 3))
.transpose((0, 2, 1, 3, 4))
.reshape((1280, 1280, 3))
)
imsave(saveim, im_panel) imsave(saveim, im_panel)
print("Saved image!!!!") print("Saved image!!!!")
@ -237,8 +260,6 @@ def compute_inception(sess, target_vars):
print("FID of score {}".format(fid)) print("FID of score {}".format(fid))
def main(model_list): def main(model_list):
if FLAGS.dataset == "imagenetfull": if FLAGS.dataset == "imagenetfull":
@ -259,45 +280,55 @@ def main(model_list):
weights = [] weights = []
for i, model_num in enumerate(model_list): for i, model_num in enumerate(model_list):
weight = model.construct_weights('context_{}'.format(i)) weight = model.construct_weights("context_{}".format(i))
initialize() initialize()
save_file = osp.join(logdir, 'model_{}'.format(model_num)) save_file = osp.join(logdir, "model_{}".format(model_num))
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i)) v_list = tf.get_collection(
v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list} tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(i)
)
v_map = {
(v.name.replace("context_{}".format(i), "context_0")[:-2]): v
for v in v_list
}
saver = tf.train.Saver(v_map) saver = tf.train.Saver(v_map)
try: try:
saver.restore(sess, save_file) saver.restore(sess, save_file)
except: except BaseException:
optimistic_remap_restore(sess, save_file, i) optimistic_remap_restore(sess, save_file, i)
weights.append(weight) weights.append(weight)
if FLAGS.dataset == "imagenetfull": if FLAGS.dataset == "imagenetfull":
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32) X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
else: else:
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32) X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
if FLAGS.dataset == "cifar10": if FLAGS.dataset == "cifar10":
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32) Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
else: else:
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32) Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32) NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
X_finals = [] X_finals = []
# Seperate loops # Seperate loops
for weight in weights: for weight in weights:
X = X_START X = X_START
steps = tf.constant(0) steps = tf.constant(0)
c = lambda i, x: tf.less(i, FLAGS.num_steps)
def c(i, x):
return tf.less(i, FLAGS.num_steps)
def langevin_step(counter, X): def langevin_step(counter, X):
scale_rate = 1 scale_rate = 1
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE) X = X + tf.random_normal(
tf.shape(X),
mean=0.0,
stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE,
)
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True) energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
x_grad = tf.gradients(energy_noise, [X])[0] x_grad = tf.gradients(energy_noise, [X])[0]
@ -305,7 +336,7 @@ def main(model_list):
if FLAGS.proj_norm != 0.0: if FLAGS.proj_norm != 0.0:
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
X = X - FLAGS.step_lr * x_grad * scale_rate X = X - FLAGS.step_lr * x_grad * scale_rate
X = tf.clip_by_value(X, 0, 1) X = tf.clip_by_value(X, 0, 1)
counter = counter + 1 counter = counter + 1
@ -318,16 +349,16 @@ def main(model_list):
X_finals.append(X_final) X_finals.append(X_final)
target_vars = {} target_vars = {}
target_vars['X_START'] = X_START target_vars["X_START"] = X_START
target_vars['Y_GT'] = Y_GT target_vars["Y_GT"] = Y_GT
target_vars['X_finals'] = X_finals target_vars["X_finals"] = X_finals
target_vars['NOISE_SCALE'] = NOISE_SCALE target_vars["NOISE_SCALE"] = NOISE_SCALE
target_vars['energy_noise'] = energy_noise target_vars["energy_noise"] = energy_noise
compute_inception(sess, target_vars) compute_inception(sess, target_vars)
if __name__ == "__main__": if __name__ == "__main__":
# model_list = [117000, 116700] # model_list = [117000, 116700]
model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)] model_list = [FLAGS.resume_iter - 300 * i for i in range(FLAGS.ensemble)]
main(model_list) main(model_list)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

22
Makefile Normal file
View file

@ -0,0 +1,22 @@
.PHONY: install lint clean
install:
@echo "creating virtual environment..."
python3 -m venv venv
@echo "run: source venv/bin/activate"
venv/bin/pip3 install -r scripts/requirements.txt
lint:
venv/bin/python3 scripts/auto_fix.py
clean:
@echo "🧹 cleaning build artifacts and cache..."
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
find . -type f -name "*.pyc" -delete 2>/dev/null || true
find . -type f -name "*.pyo" -delete 2>/dev/null || true
find . -type f -name "*.pyd" -delete 2>/dev/null || true
find . -type f -name ".coverage" -delete 2>/dev/null || true
find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
@echo "✨ cleanup complete!"

View file

@ -1,19 +1,27 @@
## resources and experiments on autonomous agents ## the AI toolkit
<br> <br>
* **[⬛ ai && ml tl; dr](deep_learning)** #### research
* **[⬛ machine learning history](deep_learning)**
* **[⬛ large language models](llms)** * **[⬛ large language models](llms)**
<br>
#### experiments
* **[⬛ agents on blockchains](crypto_agents)** * **[⬛ agents on blockchains](crypto_agents)**
* **[⬛ on quantum computing](EBMs)** * **[⬛ on quantum computing](EBMs)** (my adaptation of openAI's implicit generation and generalization in energy based
- my adaptation of openai's implicit generation and generalization in energy based models models)
<br> <br>
--- ---
### cool resources ### cool discussions
<br> <br>
* **[vub's response to AI 2027 and his take on defense (2025)](https://vitalik.eth.limo/general/2025/07/10/2027.html)**
* **[mr. vp jd vance at the ai action summit in paris (2025)](https://www.youtube.com/watch?v=MnKsxnP2IVk)** * **[mr. vp jd vance at the ai action summit in paris (2025)](https://www.youtube.com/watch?v=MnKsxnP2IVk)**

View file

@ -1,8 +1,9 @@
## crypto agents ## agents on blockchains
<br> <br>
* **[basic strategy workflow](strategy_workflow)** * **[basic strategy workflow (2023)](strategy_workflow)**
* **[trading on gmx (2023)](trading_on_gmx.md)**
<br> <br>
@ -12,7 +13,6 @@
<br> <br>
##### projects ##### projects
* **[ritual.net](https://ritual.net/)** * **[ritual.net](https://ritual.net/)**
@ -27,26 +27,37 @@
##### readings ##### readings
* **[the internet's notary public: why verifiability matters, by axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)** * **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)**
* **[cryptos role in the ai revolution, by pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)** * **[the internet's notary public: why verifiability matters, by
* **[the promise and challenges of crypto + ai applications, by vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)** axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)**
* **[on training defi agents with markov chains, by bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)** * **[cryptos role in the ai revolution, by
pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)**
* **[the promise and challenges of crypto + ai applications, by
vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)**
* **[on training defi agents with markov chains, by
bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)**
<br> <br>
##### metas ##### metas
* **eth denver 2025 ai + agentic metas**: * **eth denver 2025 ai + agentic metas**:
* **[a new dawn for decentralized and the convergence of ai x blockchain, by n. ameline](https://www.youtube.com/watch?v=HQuGtN9zidQ)** * **[a new dawn for decentralized and the convergence of ai x blockchain, by n.
* **[upgrading agent infra with onchain perpetualaAgents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)** * **[upgrading agent infra with onchain perpetualaAgents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)**
* **[ai meets blockchain: payment, multi-Agent & distributed inference, by o. jaros](https://www.youtube.com/watch?v=aPTosw4hrY0)** * **[ai meets blockchain: payment, multi-Agent & distributed inference, by o.
* **[from ai agents to agentic economies, by r. bodkin](https://www.youtube.com/watch?v=Q7eaYJ9aPpI)** * **[from ai agents to agentic economies, by r. bodkin](https://www.youtube.com/watch?v=Q7eaYJ9aPpI)**
* **[building ai agents and agent economies, by d. minarsch](https://www.youtube.com/watch?v=tDVK2Q5RY0c)** * **[building ai agents and agent economies, by d. minarsch](https://www.youtube.com/watch?v=tDVK2Q5RY0c)**
* **[building ai agents on top of hedera, by j. hall](https://www.youtube.com/watch?v=h8D6vi2m8LQ)** * **[building ai agents on top of hedera, by j. hall](https://www.youtube.com/watch?v=h8D6vi2m8LQ)**
* **[identity, privacy, and security in the new age of ai agents, by m. csernai](https://www.youtube.com/watch?v=vMcaot04RQo)** * **[identity, privacy, and security in the new age of ai agents, by m.
* **[verifiable ai agents and world superintelligence, by k. wong](https://www.youtube.com/watch?v=ngkp7HTj_4A)** * **[verifiable ai agents and world superintelligence, by k. wong](https://www.youtube.com/watch?v=ngkp7HTj_4A)**
* **[how web3 can compete in ai, by g. narula](https://www.youtube.com/watch?v=oLLM1I-3fDU)** * **[how web3 can compete in ai, by g. narula](https://www.youtube.com/watch?v=oLLM1I-3fDU)**
* **[shade agents, by m. lockyer](https://www.youtube.com/watch?v=PEfJnCtrbMU)** * **[shade agents, by m. lockyer](https://www.youtube.com/watch?v=PEfJnCtrbMU)**
* **[how ai will enable the next generatin of intent, by i. yang](https://www.youtube.com/watch?v=fbc3DpI6jiA)** * **[how ai will enable the next generatin of intent, by i. yang](https://www.youtube.com/watch?v=fbc3DpI6jiA)**
* **[2025 is the year of agents, by i. polosukhin](https://www.youtube.com/watch?v=jPyzVNcQMKw)** * **[2025 is the year of agents, by i. polosukhin](https://www.youtube.com/watch?v=jPyzVNcQMKw)**
* **[our awesome decentralized AI](https://github.com/shadowy-forest/awesome-decentralized-ai)**
<br>
##### books
* **[advances in financial machine learning](../books/advances_in_financial_machine_learning.pdf)**

View file

@ -3,7 +3,8 @@
<br> <br>
<p align="center"> <p align="center">
<img width="854" src="https://user-images.githubusercontent.com/1130416/227752772-5d739fd8-1b5c-4841-a52a-7cda308fc4df.png"> <img width="854"
src="https://user-images.githubusercontent.com/1130416/227752772-5d739fd8-1b5c-4841-a52a-7cda308fc4df.png">
</p> </p>
<br> <br>

View file

@ -5,59 +5,81 @@
### A ### A
- Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage of their price discrepancies. - Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage
- Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of clients. of their price discrepancies.
- Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of
clients.
<br> <br>
### B ### B
- Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target transaction. - Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target
- Blocks: a block contains transaction data and the hash of the previous block ensuring immutability in the blockchain network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the updates to the blockchain state. transaction.
- Blocks: a block contains transaction data and the hash of the previous block ensuring immutability in the blockchain
network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the
updates to the blockchain state.
- Block time: the time interval between blocks being added to the blockchain. - Block time: the time interval between blocks being added to the blockchain.
- Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to include the transaction to the network. This request is public (anyone can listen to it). - Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to include the transaction to
- Builders: actors that take bundles (of pendent transactions from the mempool) and create a final block to send to (multiple) relays (setting themselves afeeRecipient to receive the blocks MEV). the network. This request is public (anyone can listen to it).
- Bundles: one or more transactions that are grouped together and executed in the order they are provided. In addition to the searcher's transaction(s), a bundle can also contain other users' pending transactions from the mempool. Bundles can target specific blocks for inclusion as well. - Builders: actors that take bundles (of pendent transactions from the mempool) and create a final block to send to
(multiple) relays (setting themselves afeeRecipient to receive the blocks MEV).
- Bundles: one or more transactions that are grouped together and executed in the order they are provided.
<br> <br>
### C ### C
- Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size that they are willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until the desired size is reached. - Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size that they are
- Contract address: the address hosting some source code deployed on the Ethereum blockchain, which is executed by a triggering transaction. willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until
- Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another trader's method. the desired size is reached.
- Contract address: the address hosting some source code deployed on the Ethereum blockchain, which is executed by a
triggering transaction.
- Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another
trader's method.
<br> <br>
### D ### D
- Derivatives: financial contracts that derive their values from underlying assets. - Derivatives: financial contracts that derive their values from underlying assets.
- Dollar-cost-averaging (DCA) strategy: a one-stop automated trading, based on time intervals, and reducing the influence of market volatility. Parameters for DCA can be: currency, fixed/maximum investment, and amount, investment frequency. - Dollar-cost-averaging (DCA) strategy: a one-stop automated trading, based on time intervals, and reducing the
influence of market volatility. Parameters for DCA can be: currency, fixed/maximum investment, and amount, investment
frequency.
<br> <br>
### E ### E
- Epoch: in the context of Ethereum's block production, in each slot (every 12 seconds), a validator is randomly chosen to propose the block in that slot. An epoch contains 32 slots. - Epoch: in the context of Ethereum's block production, in each slot (every 12 seconds), a validator is randomly chosen
- Externally owned account (EOA): an account that is a combination of public address and private key, and that can be used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal address derived from the last 20 bytes of the public key of the account (with 0x appended in front). to propose the block in that slot. An epoch contains 32 slots.
- Externally owned account (EOA): an account that is a combination of public address and private key, and that can be
used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal address
derived from the last 20 bytes of the public key of the account (with 0x appended in front).
<br> <br>
### F ### F
- Frontrunning: the process by which an adversary observes transactions on the network layer and acts on this information to obtain profit. - Frontrunning: the process by which an adversary observes transactions on the network layer and acts on this
information to obtain profit.
- Fully diluted valuations (FDV): the total number of tokens multiplied by the current price of a single token. - Fully diluted valuations (FDV): the total number of tokens multiplied by the current price of a single token.
- Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their price changes. - Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their
- Future grid trading bots: bots that automate futures trading activities based on grid trading strategies (a set of orders is placed both above and below a specific reference market price for the asset). price changes.
- Future grid trading bots: bots that automate futures trading activities based on grid trading strategies (a set of
orders is placed both above and below a specific reference market price for the asset).
<br> <br>
### G ### G
- Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of execution) to have their transaction processed. - Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of execution) to have
- Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency. A gwei or gigawei is defined as 1,000,000,000 wei, the smallest base unit of Ether. Conversely, 1 ETH represents 1 billion gwei. their transaction processed.
- Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching goal of buying low and selling high. - Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency.
- Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of
orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching
goal of buying low and selling high.
<br> <br>
@ -75,10 +97,12 @@
### L ### L
- Limit orders: when one longs or shorts a contract, several execution options can be placed (usually with a fee difference). Limit orders that are set at a specific price to be traded, and there is no guarantee that the trade will be executed (see market orders and stop-loss orders). - Limit orders: when one longs or shorts a contract, several execution options can be placed (usually with a fee
- Liquidity pools: a collection of crypto assets that can be used for decentralized trading. They are essential for automated market makers (AMM), borrow-lend protocols, yield farming, synthetic assets, on-chain insurance, blockchain gaming, etc. difference). Limit orders that are set at a specific price to be traded, and there is no guarantee that the trade will
be executed (see market orders and stop-loss orders).
- Liquidity pools: a collection of crypto assets that can be used for decentralized trading.
- Liquidation threshold: the percentage at which a collateral value is counted towards the borrowing capacity. - Liquidation threshold: the percentage at which a collateral value is counted towards the borrowing capacity.
- Liquidation: when the value of a borrowed asset exceeds the collateral. Anyone can liquidate the collateral and collect the liquidation fee for themselves. - Liquidation: when the value of a borrowed asset exceeds the collateral.
- Long: traders maintain long positions, which means that they expect the price of a coin to rise in the future. - Long: traders maintain long positions, which means that they expect the price of a coin to rise in the future.
<br> <br>
@ -86,23 +110,32 @@
### M ### M
- Fully diluted market capitalization: the total token supply, multiplied by the price of a single token. - Fully diluted market capitalization: the total token supply, multiplied by the price of a single token.
- Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the price of a single token. - Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the
price of a single token.
- Margin trading: buying or sell assets with leverage. - Margin trading: buying or sell assets with leverage.
- Marginal seller: a type of seller who is willing first to leave the market if the prices are lower. - Marginal seller: a type of seller who is willing first to leave the market if the prices are lower.
- Market orders: Market orders are executed immediately at the asset's market price (see limit orders). - Market orders: Market orders are executed immediately at the asset's market price (see limit orders).
- Mean reversion strategy: a trading range (or mean reversion) strategy is based on the concept that an asset's high and low prices are a temporary effect that reverts to their mean value (average value). - Mean reversion strategy: a trading range (or mean reversion) strategy is based on the concept that an asset's high and
low prices are a temporary effect that reverts to their mean value (average value).
- Mempool: a cryptocurrency nodes mechanism for storing information on unconfirmed transactions. - Mempool: a cryptocurrency nodes mechanism for storing information on unconfirmed transactions.
- Merkle tree: a type of binary tree, composed of: 1) a set of notes with a large number of leaf nodes at the bottom, containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a single root node, also formed from the hash of its two children, representing the top of the tree. - Merkle tree: a type of binary tree, composed of: 1) a set of notes with a large number of leaf nodes at the bottom,
- Minting: the process of validating information, creating a new block, and recording that information into the blockchain. containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a
single root node, also formed from the hash of its two children, representing the top of the tree.
- Minting: the process of validating information, creating a new block, and recording that information into the
blockchain.
<br> <br>
### P ### P
- Perpetual contract: a contract without an expiration date, where interest rates can be calculated by methods such as Time-Weighted-Average-Price (TWAP). - Perpetual contract: a contract without an expiration date, where interest rates can be calculated by methods such as
- Priority gas auctions: bots compete against each other by binding up transaction fees (gas) to extract revenue from arbitrage opportunities, driving up user fees. Time-Weighted-Average-Price (TWAP).
- Private key: a secret number enabling a blockchain user to prove ownership on an account or contract, via a digital signature. - Priority gas auctions: bots compete against each other by binding up transaction fees (gas) to extract revenue from
- Publick key: a number generated by a one-way (hash) function from the private key, used to verify a digital signature made with the matching private key. arbitrage opportunities, driving up user fees.
- Private key: a secret number enabling a blockchain user to prove ownership on an account or contract, via a digital
signature.
- Publick key: a number generated by a one-way (hash) function from the private key, used to verify a digital signature
made with the matching private key.
- Provider: an entity that provides an abstraction for a connection to the blockchain network. - Provider: an entity that provides an abstraction for a connection to the blockchain network.
- POFPs: private order flow protocols. - POFPs: private order flow protocols.
@ -110,8 +143,9 @@
### O ### O
- Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state of the blockchain. - Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state
- Open interest: total number of futures contracts held by market participants at the end of the trading day. Used as an indicator to determine market sentiment and the strength behind price trends. of the blockchain.
- Open interest: total number of futures contracts held by market participants at the end of the trading day.
<br> <br>
@ -123,37 +157,48 @@
### S ### S
- Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen validator has time to propose a block. - Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen
- Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties. They are reliant upon code (the functions) and data (the state), and they can trigger specific actions, such as transferring tokens from A to B. validator has time to propose a block.
- Sandwich attack: when slippage value is not set, this attack can happen by an actor bumping the price of an asset to an unfavorable level, executing the trade, and then returning the asset to the original price. - Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties.
- Sandwich attack: when slippage value is not set, this attack can happen by an actor bumping the price of an asset to
an unfavorable level, executing the trade, and then returning the asset to the original price.
- Slippage: delta in pricing between the time of order and when the order is executed. - Slippage: delta in pricing between the time of order and when the order is executed.
- Short: traders maintain short positions, which means they expect the price of a coin to drop in the future. - Short: traders maintain short positions, which means they expect the price of a coin to drop in the future.
- Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason. This situation prompts short sellers to scramble to buy the stock to cover their positions and cap their mounting losses. - Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason.
- Spot trading: buy or selling assets for immediate delivery. - Spot trading: buy or selling assets for immediate delivery.
- Statistical trading: is the class of strategies that aim to generate profitable situations, stemming from pricing inefficiencies among financial markets. Statistical arbitrage is a strategy to obtain profit by applying past statistics. - Statistical trading: is the class of strategies that aim to generate profitable situations, stemming from pricing
- Stop-loss orders: this type of order execution places a market/limit order to close a position to restrict an investor's loss on a crypto asset. inefficiencies among financial markets. Statistical arbitrage is a strategy to obtain profit by applying past
statistics.
- Stop-loss orders: this type of order execution places a market/limit order to close a position to restrict an
investor's loss on a crypto asset.
<br> <br>
### T ### T
- otal value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or derivatives protocols. - otal value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or
derivatives protocols.
- Тrading volume: the total amount of traded cryptocurrency (equivalent to US dollars) during a given timeframe. - Тrading volume: the total amount of traded cryptocurrency (equivalent to US dollars) during a given timeframe.
- Transaction: on EVM-based blockchains, there the two types of transactions are normal transactions and contract interactions. - Transaction: on EVM-based blockchains, there the two types of transactions are normal transactions and contract
interactions.
- Transaction hash: a unique 66-character identifier generated with each new transaction. - Transaction hash: a unique 66-character identifier generated with each new transaction.
- Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block, allowing attacks that benefit from certain ordering. - Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block,
- Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using evenly divided time slots between a start and end time. allowing attacks that benefit from certain ordering.
- Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined
smaller chunks of the order to the market, using evenly divided time slots between a start and end time.
<br> <br>
### V ### V
- Validation: a mathematical proof that the state change in the blockchain is consistent. To be included into a block in the blockchain, a list of transactions needs to be validated. - Validation: a mathematical proof that the state change in the blockchain is consistent.
- VTRPs: validator transaction Reordering protocols. - VTRPs: validator transaction Reordering protocols.
- Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using historical volume profiles. - Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller
chunks of the order to the market, using historical volume profiles.
<br> <br>
### W ### W
- Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become powerful enough to manipulate the valuation. - Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become
powerful enough to manipulate the valuation.

View file

@ -2,5 +2,6 @@
<br> <br>
* perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients (using the simulator and a set of historical data) * perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients
(using the simulator and a set of historical data)
* overfitting to historical data is a big risk (be careful with validation and test sets). * overfitting to historical data is a big risk (be careful with validation and test sets).

View file

@ -2,4 +2,5 @@
<br> <br>
* before the strategy goes live, simulation is done on new market data, in real-time (paper trading), which prevents overfitting * before the strategy goes live, simulation is done on new market data, in real-time (paper trading), which prevents
overfitting

View file

@ -2,4 +2,5 @@
<br> <br>
* come with a rule-based policy that determines what actions to take based on the current state of the market and the outpus of supervised models. * come with a rule-based policy that determines what actions to take based on the current state of the market and the
outputs of supervised models.

View file

@ -2,8 +2,12 @@
<br> <br>
* **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period, minus trading fees * **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period,
minus trading fees
* **alpha nad beta** * **alpha nad beta**
* **shape ratio:** the excess return per unit of risk you are taking (return on capital over the standard deviation adjusted for risk; the higher the better). * **shape ratio:** the excess return per unit of risk you are taking (return on capital over the standard deviation
* **maximum drawdown:** maximum difference between a local maximum and a subsequent local minimum as an another measure of risk. adjusted for risk; the higher the better).
* **value at risk (var):** how much capital you may lose over a given time frame with some probability, assumong normal market conditions. * **maximum drawdown:** maximum difference between a local maximum and a subsequent local minimum as an another measure
of risk.
* **value at risk (var):** how much capital you may lose over a given time frame with some probability, assumong normal
market conditions.

View file

@ -2,4 +2,5 @@
<br> <br>
* train one or more supervised learning models to predict quantities of interest that are necessary for the strategy work, for example, price prediction, quantity prediction, etc. * train one or more supervised learning models to predict quantities of interest that are necessary for the strategy
work, for example, price prediction, quantity prediction, etc.

View file

@ -13,8 +13,10 @@
<br> <br>
<img width="400" src="https://user-images.githubusercontent.com/1130416/227733463-d0dff53f-9a5f-45f3-80a4-9d9ab0d9201e.png"> <img width="400"
<img width="400" src="https://user-images.githubusercontent.com/1130416/227733575-90550afd-99f2-45cc-b6aa-fd4457910cc5.png"> src="https://user-images.githubusercontent.com/1130416/227733463-d0dff53f-9a5f-45f3-80a4-9d9ab0d9201e.png">
<img width="400"
src="https://user-images.githubusercontent.com/1130416/227733575-90550afd-99f2-45cc-b6aa-fd4457910cc5.png">
<br> <br>
@ -25,10 +27,12 @@
<br> <br>
* the order book is made of two sides, asks (sell, offers) and bids (buy). * the order book is made of two sides, asks (sell, offers) and bids (buy).
* the best ask (the lowest price someone is willing to sell ) > the best bid (the highest price someone is willing to buy). * the best ask (the lowest price someone is willing to sell ) > the best bid (the highest price someone is willing to
buy).
* the difference between the best ask and the best bid is called spread. * the difference between the best ask and the best bid is called spread.
* **market order**: best price possible, right now. it takes liquidity from the market and usually has higher fees. * **market order**: best price possible, right now. it takes liquidity from the market and usually has higher fees.
* **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the match. * **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the
match.
* **stop orders**: allow you to set a maximum price for your market orders. * **stop orders**: allow you to set a maximum price for your market orders.
<br> <br>

View file

@ -1,4 +1,4 @@
## ai agents ## some machine learning history
<br> <br>
@ -13,8 +13,6 @@
<br> <br>
* **[cursor ai editor](https://www.cursor.com/)**
* **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)**
* **[google's jax (composable transformations of numpy programs)](https://github.com/google/jax)** * **[google's jax (composable transformations of numpy programs)](https://github.com/google/jax)**
* **[machine learning engineering open book](https://github.com/stas00/ml-engineering)** * **[machine learning engineering open book](https://github.com/stas00/ml-engineering)**
* **[advances in financial machine learning](books/advances_in_financial_machine_learning.pdf)**

View file

@ -10,9 +10,11 @@
* **[2013: atari with deep reinforcement learning](https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial)** * **[2013: atari with deep reinforcement learning](https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial)**
* **[2014: seq2seq](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt)** * **[2014: seq2seq](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt)**
* **[2014: adam optmizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)** * **[2014: adam
optmizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)**
* **[2015: gans](https://www.tensorflow.org/tutorials/generative/dcgan)** * **[2015: gans](https://www.tensorflow.org/tutorials/generative/dcgan)**
* **[2015: resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)** * **[2015:
resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)**
* **[2017: transformers](https://github.com/huggingface/transformers)** * **[2017: transformers](https://github.com/huggingface/transformers)**
* **[2018: bert](https://arxiv.org/abs/1810.04805)** * **[2018: bert](https://arxiv.org/abs/1810.04805)**
@ -24,30 +26,38 @@
<br> <br>
* a map consists of a set of states, a set of actions, a transition function that describes the probability of moving rom one state to another after taking an action, and a reward function that assigns a numerical reward to each state-action pair * a map consists of a set of states, a set of actions, a transition function that describes the probability of moving
rom one state to another after taking an action, and a reward function that assigns a numerical reward to each
state-action pair
* the goal of a map is to maximize its expected cumulative reward over a sequence of actions, called a policy. * the goal of a map is to maximize its expected cumulative reward over a sequence of actions, called a policy.
* a policy is a function that maps each state to a probability distribution over actions. The optimal policy is the one that maximizes the expected cumulative rewards. * a policy is a function that maps each state to a probability distribution over actions.
* the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as the optimal control of incompletely-known Markov decision processes. * the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as
the optimal control of incompletely-known Markov decision processes.
* as opposed to supervised learning, an agent must be able to learn from its own experience. and as oppose to unsupervised learning because, reinforcement learning is trying to maximize a reward signal instead of trying to find hidden structure. * as opposed to supervised learning, an agent must be able to learn from its own experience.
* the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain a reliable estimate of its expected reward. * the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in
order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain
a reliable estimate of its expected reward.
* beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a policy, a reward signal, a value function, and, optionally, a model of the environment. * beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a
policy, a reward signal, a value function, and, optionally, a model of the environment.
* traditional reinforcement learning problems can be formulated as a markov decision process (MDP): * traditional reinforcement learning problems can be formulated as a markov decision process (MDP):
* we have an agent acting in an environment * we have an agent acting in an environment
* each step *t* the agent receives as the input the current state S_t, takes an action A_t, and receives a reward R_{t+1} and the next state S_{t+1} * each step *t* the agent receives as the input the current state S_t, takes an action A_t, and receives a reward
R_{t+1} and the next state S_{t+1}
* the agent choose the action based on some policy pi: A_t = pi(S_t) * the agent choose the action based on some policy pi: A_t = pi(S_t)
* it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon * it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon
<br> <br>
<img width="500" src="https://user-images.githubusercontent.com/1130416/227799494-d62aab7f-d6cf-419f-be03-1d2dbdee1853.png"> <img width="500"
src="https://user-images.githubusercontent.com/1130416/227799494-d62aab7f-d6cf-419f-be03-1d2dbdee1853.png">
<br> <br>
@ -55,7 +65,7 @@
<br> <br>
* agent is the trading agent (e.g. the human trader who opens the gui of an exchange and makes trading decision based on the current state of the exchange and their account) * agent is the trading agent (e.g.
<br> <br>
@ -65,7 +75,8 @@
* the exchange and other agents are the environment, and they are not something we can control * the exchange and other agents are the environment, and they are not something we can control
* by putting other agents together into some big complex environment, we lose the ability to explicitly model them * by putting other agents together into some big complex environment, we lose the ability to explicitly model them
* if we try to reverse-engineer the algorithms and strategies that other traders are running, put us into a multi-agent reinforcement learning (MARL) problem setting * if we try to reverse-engineer the algorithms and strategies that other traders are running, put us into a multi-agent
reinforcement learning (MARL) problem setting
<br> <br>
@ -73,11 +84,12 @@
<br> <br>
* in the case of trading on an exchange, we don't observe the complete state of the environment (e.g. other agents), so we are dealing with a partially observable markov decision process (pomdp). * in the case of trading on an exchange, we don't observe the complete state of the environment (e.g.
* what the agents observe is not the actual state S_t of the environment, but some derivation of that. * what the agents observe is not the actual state S_t of the environment, but some derivation of that.
* we can call the observation X_t, which is calculated using some function of the full state X_t ~ O(S_t) * we can call the observation X_t, which is calculated using some function of the full state X_t ~ O(S_t)
* the observation at each timestep t is simply the history of all exchange events received up to time t. * the observation at each timestep t is simply the history of all exchange events received up to time t.
* this event history can be used to build up the current exchange state, however, in order for our agent to make decisions, extra info such as account balance and open limit orders need to be included. * this event history can be used to build up the current exchange state, however, in order for our agent to make
decisions, extra info such as account balance and open limit orders need to be included.
<br> <br>
@ -85,8 +97,9 @@
<br> <br>
* hft techniques: decisions are based almost entirely on market microstructure signals. decisions are made on nanoseconds timescales and trading strategies use dedicated connections to exchanges and extremly fast but simple algorithms running fpga hardware. * hft techniques: decisions are based almost entirely on market microstructure signals.
* neural networks are slow, they can't make predictions on nanoseconds time scales, so they can't compete with the speed of hft algorithms. * neural networks are slow, they can't make predictions on nanoseconds time scales, so they can't compete with the speed
of hft algorithms.
* guess: the optimal time scale is between a few milliseconds and a few minutes. * guess: the optimal time scale is between a few milliseconds and a few minutes.
* can deep rl algorithms pick up hidden patterns? * can deep rl algorithms pick up hidden patterns?
@ -96,9 +109,11 @@
<br> <br>
* the simplest approach has 3 actions: buy, hold, and sell. this works but limits us to placing market orders and to invest a deterministic amount of money at each step. * the simplest approach has 3 actions: buy, hold, and sell.
* in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model, putting us into a continuous action space. * in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model,
* in the next level, we would introduce limit orders, and the agent needs to decide the level (price) and wuantity of the order, and be able to cancel orders that have not been yet matched. putting us into a continuous action space.
* in the next level, we would introduce limit orders, and the agent needs to decide the level (price) and wuantity of
the order, and be able to cancel orders that have not been yet matched.
<br> <br>
@ -106,18 +121,21 @@
<br> <br>
* there are several possible reward functions, an obvious would realized PnL (profit and loss). the agent receives a reward whenever it closes a position. * there are several possible reward functions, an obvious would realized PnL (profit and loss).
* the net profit is either negative or positive, and this is the reward signal. * the net profit is either negative or positive, and this is the reward signal.
* as the agent maximize the total cumulative reward, it learns to trade profitably. the reward function leads to the optimal policy in the limit. * as the agent maximize the total cumulative reward, it learns to trade profitably.
* however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent feedback. * however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent
* an alternative is unrealized pnl, which the net profit the agent would get if it were to close all of its positions immediately. feedback.
* because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals. however the direct feedback may bias the agent towards short-term actions. * an alternative is unrealized pnl, which the net profit the agent would get if it were to close all of its positions
immediately.
* because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals.
* both naively optimize for profit, but a trader may want to minimize risk (lower volatility) * both naively optimize for profit, but a trader may want to minimize risk (lower volatility)
* using the sharpe ration is one simple way to take risk into account. other way is maximum drawdown. * using the sharpe ration is one simple way to take risk into account. other way is maximum drawdown.
<br> <br>
<img width="505" src="https://user-images.githubusercontent.com/1130416/227811225-9af06c79-3f86-48e8-899c-ee5a80bc91e1.png"> <img width="505"
src="https://user-images.githubusercontent.com/1130416/227811225-9af06c79-3f86-48e8-899c-ee5a80bc91e1.png">
<br> <br>
@ -134,11 +152,14 @@
<br> <br>
* we need separate backtesting and parameter optimization steps because it was difficult for our strategies to take into account environmental factors: order book liquidity, fee structures, latencies. * we need separate backtesting and parameter optimization steps because it was difficult for our strategies to take into
* getting around environmental limitations is part of the opimization process. if we simulate the latency in the reinforcement learning environment, and this results in the agent making a mistake, the agent will get a negative rewards, forcing it to learn to work around the latencies. account environmental factors: order book liquidity, fee structures, latencies.
* by learning a model of the environment and performing rollouts using techniques like a monte carlo tree search (mcts), we could take into account potential reactions of the market (other agents) * getting around environmental limitations is part of the opimization process.
* by learning a model of the environment and performing rollouts using techniques like a monte carlo tree search (mcts),
we could take into account potential reactions of the market (other agents)
* by being smart about the data we collect from the live environment, we can continously improve our model * by being smart about the data we collect from the live environment, we can continously improve our model
* do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting information that we can use to improve the model of our environment and other agents? * do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting
information that we can use to improve the model of our environment and other agents?
<br> <br>
@ -147,7 +168,9 @@
<br> <br>
* some strategy may work better in a bearish environment but lose money in a bullish environment. * some strategy may work better in a bearish environment but lose money in a bullish environment.
* because rl agents are learning powerful policies parameterized by NN, they can alos learn to adapt to market conditions by seeing them in historical data, given that they are trained over long time horizon and have sufficient memory. * because rl agents are learning powerful policies parameterized by NN, they can alos learn to adapt to market
conditions by seeing them in historical data, given that they are trained over long time horizon and have sufficient
memory.
<br> <br>
@ -156,9 +179,13 @@
<br> <br>
* the trading environment is a multiplayer game with thousands of agents acting simultaneously * the trading environment is a multiplayer game with thousands of agents acting simultaneously
* understanding how to build models of other agents is only one possible we can, we can choose perfom actions in a live environment with the goal of maximizing the information grain with respect to kind policies the other agents may be following * understanding how to build models of other agents is only one possible we can, we can choose perfom actions in a live
environment with the goal of maximizing the information grain with respect to kind policies the other agents may be
following
* trading agents receive sparse rewards from the market. naively applying reward-hungry rl algorithms will fail. * trading agents receive sparse rewards from the market. naively applying reward-hungry rl algorithms will fail.
* this opens up the possibility for new algorithms and techniques, that can efficiently deal with sparse rewards. * this opens up the possibility for new algorithms and techniques, that can efficiently deal with sparse rewards.
* many of today's standard algorithms, such as dqn or a3c, use a very naive approach exploration - basically adding random noise to the policy. however, in the trading case, most states in the environment are bad, and there are only a few good ones. a naive random approach to exploration will almost never stumble upon good state-actions paris. * many of today's standard algorithms, such as dqn or a3c, use a very naive approach exploration - basically adding
* the trading environment is inherently nonstationary. market conditions change and other agent join, leave, and constantly change their strategies. random noise to the policy. however, in the trading case, most states in the environment are bad, and there are only a
few good ones. a naive random approach to exploration will almost never stumble upon good state-actions paris.
* the trading environment is inherently nonstationary.
* can we train an agent that can transit from bear to bull and then back to bear, without needing to be re-trained? * can we train an agent that can transit from bear to bull and then back to bear, without needing to be re-trained?

View file

@ -6,8 +6,10 @@
<br> <br>
* reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward signal * reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward
* an autonomous agent is a software program or system that can operate independently and make decisions on its own, without direct intervention from a human signal
* an autonomous agent is a software program or system that can operate independently and make decisions on its own,
without direct intervention from a human
<br> <br>
@ -17,10 +19,13 @@
<br> <br>
* we formalize the problem of reinforcement using ideas from dynamical system theory, as the optimal control of incompletely-known Markov decision processes. * we formalize the problem of reinforcement using ideas from dynamical system theory, as the optimal control of
* a learning agent must be able to sense the state of its environment to some extent and must be able to take actions that affect the state. incompletely-known Markov decision processes.
* a learning agent must be able to sense the state of its environment to some extent and must be able to take actions
that affect the state.
* markov decision processes are intented to include just these three aspects, sensation, action, and goal. * markov decision processes are intented to include just these three aspects, sensation, action, and goal.
* the agent has to exploit what it has already experienced in order to obtain reward, but it has also to explore in order to make better action selections in the future. * the agent has to exploit what it has already experienced in order to obtain reward, but it has also to explore in
order to make better action selections in the future.
* on a stochastic tasks, each action must be tried many times to gain a reliable estimate of its expected reward. * on a stochastic tasks, each action must be tried many times to gain a reliable estimate of its expected reward.
<br> <br>
@ -31,12 +36,17 @@
<br> <br>
* beyond the agent and the environment, 4 more elements belong to a reinforcement learning system: a policy, a reward signal, a value funtion, and a model of the environmnet. * beyond the agent and the environment, 4 more elements belong to a reinforcement learning system: a policy, a reward
* a policy defines the learning agent's way of behacing at a given time. It's a mapping from perceiv ed states of the environment to actions to be taken when in those states. in general, policies may be stochastics (specifying probabilities for each action). signal, a value funtion, and a model of the environmnet.
* a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total reward over the run. * a policy defines the learning agent's way of behacing at a given time.
* a value function specifies what is good in the long run, the valye of a state in the total amount of reward an agent can expect to accumulate over the future, starting from that state * a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the
reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total
reward over the run.
* a value function specifies what is good in the long run, the valye of a state in the total amount of reward an agent
can expect to accumulate over the future, starting from that state
* a model of the environment. * a model of the environment.
* the most important feature distinguishing reinforcement learning from other types of learning is that it uses training information that evaluates the actions taken rather than instructs by giving correct actions. * the most important feature distinguishing reinforcement learning from other types of learning is that it uses training
information that evaluates the actions taken rather than instructs by giving correct actions.
<br> <br>
@ -47,7 +57,8 @@
<br> <br>
* the problem involves evaluating feedbacks and choosing different actions in different situations. * the problem involves evaluating feedbacks and choosing different actions in different situations.
* mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards, but also subsequent situations. * mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards,
but also subsequent situations.
* mdps involve delayed reward and the need to trade off immediate and delayed reward. * mdps involve delayed reward and the need to trade off immediate and delayed reward.
<br> <br>
@ -57,32 +68,40 @@
* mdps are meant to be a straightfoward framing of the problem of learning from interaction to achieve a goal. * mdps are meant to be a straightfoward framing of the problem of learning from interaction to achieve a goal.
* the learner and the decision makers is called the agent. * the learner and the decision makers is called the agent.
* the thing it interacts with, comprimising everything outside the agent, is called the environment. * the thing it interacts with, comprimising everything outside the agent, is called the environment.
* the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice of actions. * the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice
of actions.
<br> <br>
<img width="466" src="https://user-images.githubusercontent.com/1130416/228971927-3c574911-d0ca-4d2d-b795-8b0776599952.png"> <img width="466"
src="https://user-images.githubusercontent.com/1130416/228971927-3c574911-d0ca-4d2d-b795-8b0776599952.png">
<br> <br>
* the agent and the environment interact at each of a sequence of discrete steps, t = 0, 1, 2, 3... * the agent and the environment interact at each of a sequence of discrete steps, t = 0, 1, 2, 3...
* at each time step t, the agent receives some representation of the environments state St * at each time step t, the agent receives some representation of the environments state St
* on that basis, the agent selects an action At * on that basis, the agent selects an action At
* one step later, in part of a consequence of its action, the agent receives a numerical rewards and finds itself in a new state. * one step later, in part of a consequence of its action, the agent receives a numerical rewards and finds itself in a
new state.
* the mdp and the agent together give rise to a sequence (trajectory) * the mdp and the agent together give rise to a sequence (trajectory)
* in a finite mdp, the set of states, actions, and rewards all have a finite number of elements. in this case, the random variables R and S have well defined discrete probability distributions dependent only on the proceding state and action. * in a finite mdp, the set of states, actions, and rewards all have a finite number of elements.
* in a markov decision process, the probabilities given by p completely characterize the environment's dynamics. * in a markov decision process, the probabilities given by p completely characterize the environment's dynamics.
* the state must include information about all aspects of the past agent-environment interaction that make a differnce for the future. * the state must include information about all aspects of the past agent-environment interaction that make a differnce
* anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its environment. for the future.
* anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its
environment.
<br> <br>
##### goals and rewards ##### goals and rewards
* each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to a sample from a standard distribution of starting states. * each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to
* almost all reinforcement learning algorithms involve estimating value functions—functions of states (or of stateaction pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a given action in a given state). a sample from a standard distribution of starting states.
* the Bellman equation averages over all the possibilities, weighting each by its probability of occurring. tt states that the value of the start state must equal the * almost all reinforcement learning algorithms involve estimating value functions—functions of states (or of
stateaction pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a
given action in a given state).
* the Bellman equation averages over all the possibilities, weighting each by its probability of occurring.
(discounted) value of the expected next state, plus the reward expected along the way. (discounted) value of the expected next state, plus the reward expected along the way.
* solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run. * solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run.
@ -92,10 +111,14 @@
### dynamic programming ### dynamic programming
* collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as a mdp. * collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as a
* a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state and action spaces and then apply finite-state DP methods. mdp.
* a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state
and action spaces and then apply finite-state DP methods.
* the reason for computing the value function for a policy is to help find better policies. * the reason for computing the value function for a policy is to help find better policies.
* asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other states happen to be available. the values of some states may be updated several times before the values of others ar * asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps
of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other
states happen to be available. the values of some states may be updated several times before the values of others ar
* policy evaluation refers to the (typi- cally) iterative computation of the value function for a given policy. * policy evaluation refers to the (typi- cally) iterative computation of the value function for a given policy.
* policy improvement refers to the computation of an improved policy given the value function for that policy. * policy improvement refers to the computation of an improved policy given the value function for that policy.
@ -103,9 +126,13 @@
##### generalized policy interaction ##### generalized policy interaction
* policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with the current policy (policy evaluation), and the other making the policy greedy with respect to the current value function (policy improvement). * policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with
* generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement processes interact, independent of the granularity and other details of the two processes. the current policy (policy evaluation), and the other making the policy greedy with respect to the current value
* DP is sometimes thought to be of limited applicability because of the curse of dimen- sionality, the fact that the number of states often grows exponentially with the number of state variables function (policy improvement).
* generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement
processes interact, independent of the granularity and other details of the two processes.
* DP is sometimes thought to be of limited applicability because of the curse of dimen- sionality, the fact that the
number of states often grows exponentially with the number of state variables
<br> <br>

View file

@ -2,9 +2,9 @@
<br> <br>
* **[opeanai](opeanai)** * **[google's gemini](gemini)**
* **[openAI](openAI)**
* **[claude](claude)** * **[claude](claude)**
* **[gemini](gemini)**
* **[deepseek](deepseek)** * **[deepseek](deepseek)**
<br> <br>
@ -17,7 +17,7 @@
#### articles #### articles
* **[people cannot distinguish gpt-4 from a human in a turing test, by c. jones et al (2024)](https://arxiv.org/pdf/2405.08007)** * **[people cannot distinguish gpt-4 from a human in a turing test, by c.
<br> <br>

View file

@ -2,6 +2,12 @@
<br> <br>
<br>
---
### cool resources ### cool resources
<br> <br>

View file

@ -3,6 +3,8 @@
<br> <br>
<p align="center"> <p align="center">
<img src="https://github.com/user-attachments/assets/42b8c4ac-4359-422a-a0a4-dd4ff0ec6e75" width="60%" align="center" style="padding:1px;border:1px solid black;" /> <img src="https://github.com/user-attachments/assets/42b8c4ac-4359-422a-a0a4-dd4ff0ec6e75" width="60%" align="center"
<img src="https://github.com/user-attachments/assets/a1b2b912-8700-439f-8ad8-db415c94ad0b" width="60%" align="center" style="padding:1px;border:1px solid black;" /> style="padding:1px;border:1px solid black;" />
<img src="https://github.com/user-attachments/assets/a1b2b912-8700-439f-8ad8-db415c94ad0b" width="60%" align="center"
style="padding:1px;border:1px solid black;" />
</p> </p>

View file

@ -1,16 +1,25 @@
## openai ## openAI
<br> <br>
<br>
---
### cool resources ### cool resources
<br> <br>
* **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode) (and [here](https://marketplace.visualstudio.com/items?itemName=timkmecl.chatgpt))** * **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode)**
* **[scispace extension (paper explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)** * **[scispace extension (paper
explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)**
* **[fix python bugs](https://platform.openai.com/playground/p/default-fix-python-bugs?model=code-davinci-002)** * **[fix python bugs](https://platform.openai.com/playground/p/default-fix-python-bugs?model=code-davinci-002)**
* **[explain code](https://platform.openai.com/playground/p/default-explain-code?model=code-davinci-002)** * **[explain code](https://platform.openai.com/playground/p/default-explain-code?model=code-davinci-002)**
* **[translate code](https://platform.openai.com/playground/p/default-translate-code?model=code-davinci-002)** * **[translate code](https://platform.openai.com/playground/p/default-translate-code?model=code-davinci-002)**
* **[translate sql](https://platform.openai.com/playground/p/default-sql-translate?model=code-davinci-002)** * **[translate sql](https://platform.openai.com/playground/p/default-sql-translate?model=code-davinci-002)**
* **[calculate time complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)** * **[calculate time
* **[text to programmatic command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)** complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)**
* **[text to programmatic
command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)**

68
pyproject.toml Normal file
View file

@ -0,0 +1,68 @@
[tool.black]
line-length = 88
target-version = ['py39']
include = '\.pyi?$'
extend-exclude = '''
/(
# directories
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| venv
| _build
| buck-out
| build
| dist
)/
'''
[tool.isort]
profile = "black"
multi_line_output = 3
line_length = 88
known_first_party = []
known_third_party = []
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
skip = ["venv", ".venv", "__pycache__"]
[tool.autopep8]
max_line_length = 88
aggressive = 2
experimental = true
[tool.autoflake]
remove-all-unused-imports = true
remove-unused-variables = true
remove-duplicate-keys = true
ignore-init-module-imports = true
[tool.flake8]
max-line-length = 88
extend-ignore = ["E203", "W503"]
exclude = [
".git",
"__pycache__",
"venv",
".venv",
"build",
"dist",
".eggs"
]
[tool.mypy]
python_version = "3.9"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true
strict_equality = true

628
scripts/auto_fix.py Executable file
View file

@ -0,0 +1,628 @@
#!/usr/bin/env python3
"""
this script fixes common quality issues including:
- code formatting (black, isort, autopep8, autoflake)
- markdown formatting (line length, trailing whitespace)
- markdown link validation (internal and external)
- python code quality issues
- import organization
"""
import os
import re
import subprocess
import time
from pathlib import Path
import requests
class AutoFixer:
def __init__(self):
self.venv_path = "venv"
self.fixes_applied = 0
self.errors_encountered = 0
self.link_report = {
"total_links": 0,
"internal_links": 0,
"external_links": 0,
"broken_internal": 0,
"broken_external": 0,
"broken_links": [],
}
def fix_python_code(self) -> bool:
print("\n🐍 fixing python code...")
python_files = list(Path(".").rglob("*.py"))
python_files = [
f
for f in python_files
if not any(
part.startswith(".") or part in ["venv", "__pycache__", ".venv"]
for part in f.parts
)
]
if not python_files:
print(" no python files found to fix")
return True
print(f" found {len(python_files)} Python files to fix")
print("🔧 autoflake - removing unused imports...")
if self._run_autoflake(python_files):
self.fixes_applied += 1
print("✅ autoflake completed")
else:
print("⚠️ autoflake had issues")
print("🔧 autopep8 - fixing code style...")
if self._run_autopep8(python_files):
self.fixes_applied += 1
print("✅ autopep8 completed")
else:
print("⚠️ autopep8 had issues")
print("🔧 isort - organizing imports...")
if self._run_isort(python_files):
self.fixes_applied += 1
print("✅ isort completed")
else:
print("⚠️ isort had issues")
print("🔧 black - applying consistent formatting...")
if self._run_black(python_files):
self.fixes_applied += 1
print("✅ black completed")
else:
print("⚠️ black had issues")
return True
def _run_autoflake(self, python_files: list[Path]) -> bool:
try:
for file_path in python_files:
cmd = [
f"{self.venv_path}/bin/autoflake",
"--in-place",
"--remove-all-unused-imports",
"--remove-unused-variables",
str(file_path),
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"autoflake warning for {file_path}: {result.stderr}")
return True
except Exception as e:
print(f"autoflake error: {e}")
def _run_autopep8(self, python_files: list[Path]) -> bool:
try:
for file_path in python_files:
cmd = [
f"{self.venv_path}/bin/autopep8",
"--in-place",
"--aggressive",
"--aggressive",
str(file_path),
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"autopep8 warning for {file_path}: {result.stderr}")
return True
except Exception as e:
print(f"autopep8 error: {e}")
def _run_isort(self, python_files: list[Path]) -> bool:
try:
for file_path in python_files:
cmd = [f"{self.venv_path}/bin/isort", str(file_path)]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"isort warning for {file_path}: {result.stderr}")
return True
except Exception as e:
print(f"isort error: {e}")
def _run_black(self, python_files: list[Path]) -> bool:
try:
import subprocess
for file_path in python_files:
cmd = [f"{self.venv_path}/bin/black", str(file_path)]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"black warning for {file_path}: {result.stderr}")
return True
except Exception as e:
print(f"black error: {e}")
def fix_markdown_files(self) -> bool:
print("\n📝 fixing markdown files...")
markdown_files = list(Path(".").rglob("*.md"))
markdown_files = [
f
for f in markdown_files
if not any(
part.startswith(".") or part in ["venv", "__pycache__", ".venv"]
for part in f.parts
)
]
if not markdown_files:
print(" no markdown files found to fix")
return True
print(f" found {len(markdown_files)} markdown files to fix")
print("🔗 checking markdown links...")
self.check_markdown_links(markdown_files)
self.fix_common_link_issues(markdown_files)
for file_path in markdown_files:
if self.fix_single_markdown_file(file_path):
self.fixes_applied += 1
return True
def check_markdown_links(self, markdown_files: list[Path]) -> None:
all_links = []
for file_path in markdown_files:
links = self.extract_links_from_markdown(file_path)
for link in links:
link["source_file"] = file_path
all_links.append(link)
if not all_links:
print(" no links found in markdown files")
return
self.link_report["total_links"] = len(all_links)
print(f" found {len(all_links)} links to check")
internal_links = [link for link in all_links if not link["is_external"]]
self.link_report["internal_links"] = len(internal_links)
if internal_links:
print(f"🔍 checking {len(internal_links)} internal links...")
self.check_internal_links(internal_links)
external_links = [link for link in all_links if link["is_external"]]
self.link_report["external_links"] = len(external_links)
if external_links:
print(f"🌐 checking {len(external_links)} external links...")
self.check_external_links(external_links)
def extract_links_from_markdown(self, file_path: Path) -> list[dict]:
links = []
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
link_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
matches = re.findall(link_pattern, content)
for text, url in matches:
is_external = url.startswith(("http://", "https://", "mailto:"))
links.append(
{
"text": text.strip(),
"url": url.strip(),
"is_external": is_external,
"line_number": self.get_line_number_for_link(
content, text, url
),
}
)
except Exception as e:
print(f"❌ error reading {file_path}: {e}")
return links
def get_line_number_for_link(self, content: str, text: str, url: str) -> int:
lines = content.split("\n")
for i, line in enumerate(lines, 1):
if f"[{text}]({url})" in line:
return i
return 0
def check_internal_links(self, internal_links: list[dict]) -> None:
broken_links = []
for link in internal_links:
source_file = link["source_file"]
url = link["url"]
text = link["text"]
line_num = link["line_number"]
if url.startswith("../"):
target_path = source_file.parent.parent / url[3:]
elif url.startswith("./"):
target_path = source_file.parent / url[2:]
else:
target_path = source_file.parent / url
if not target_path.exists():
broken_link_info = {
"source_file": source_file,
"text": text,
"url": url,
"line_number": line_num,
"target_path": target_path,
"issue": "File not found",
"type": "internal",
}
broken_links.append(broken_link_info)
self.link_report["broken_links"].append(broken_link_info)
self.link_report["broken_internal"] = len(broken_links)
if broken_links:
print(f"❌ found {len(broken_links)} broken internal links:")
for link in broken_links:
print(
f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
)
else:
print("✅ all internal links are valid")
def check_external_links(self, external_links: list[dict]) -> None:
broken_links = []
checked_count = 0
for link in external_links:
url = link["url"]
text = link["text"]
source_file = link["source_file"]
line_num = link["line_number"]
try:
time.sleep(0.1)
response = requests.head(url, timeout=10, allow_redirects=True)
checked_count += 1
if response.status_code >= 400:
if (
response.status_code == 429
or response.status_code == 403
or response.status_code == 443
):
if checked_count % 10 == 0:
print(
f" checked {checked_count}/{len(external_links)} external links..."
)
continue
broken_link_info = {
"source_file": source_file,
"text": text,
"url": url,
"line_number": line_num,
"status_code": response.status_code,
"issue": f"HTTP {response.status_code}",
"type": "external",
}
broken_links.append(broken_link_info)
self.link_report["broken_links"].append(broken_link_info)
if checked_count % 10 == 0:
print(
f" checked {checked_count}/{len(external_links)} external links..."
)
except requests.exceptions.RequestException as e:
broken_link_info = {
"source_file": source_file,
"text": text,
"url": url,
"line_number": line_num,
"issue": f"Connection error: {str(e)}",
"type": "external",
}
broken_links.append(broken_link_info)
self.link_report["broken_links"].append(broken_link_info)
except Exception as e:
broken_link_info = {
"source_file": source_file,
"text": text,
"url": url,
"line_number": line_num,
"issue": f"Error: {str(e)}",
"type": "external",
}
broken_links.append(broken_link_info)
self.link_report["broken_links"].append(broken_link_info)
self.link_report["broken_external"] = len(broken_links)
if broken_links:
print(f"❌ found {len(broken_links)} broken external links:")
for link in broken_links:
print(
f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
)
else:
print("✅ all external links are accessible")
print(f" checked {checked_count} external links")
def fix_common_link_issues(self, markdown_files: list[Path]) -> None:
print("🔧 fixing common link issues...")
fixed_count = 0
for file_path in markdown_files:
if self.fix_links_in_file(file_path):
fixed_count += 1
if fixed_count > 0:
print(f"✅ fixed links in {fixed_count} files")
self.fixes_applied += 1
else:
print(" no link issues found to fix")
def fix_links_in_file(self, file_path: Path) -> bool:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
changes_made = False
eth_pattern = r"([a-zA-Z0-9]+\.eth)"
if re.search(eth_pattern, content):
content = re.sub(eth_pattern, r"\1".replace(".eth", ""), content)
changes_made = True
double_space_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
def fix_spaces(match):
text = match.group(1).strip()
url = match.group(2).strip()
if text != match.group(1) or url != match.group(2):
return f"[{text}]({url})"
return match.group(0)
new_content = re.sub(double_space_pattern, fix_spaces, content)
if new_content != content:
content = new_content
changes_made = True
if changes_made:
with open(file_path, "w", encoding="utf-8") as f:
f.write(content)
print(f" 🔧 fixed links in {file_path}")
return True
except Exception as e:
print(f"❌ error fixing links in {file_path}: {e}")
return False
def fix_single_markdown_file(self, file_path: Path) -> bool:
try:
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
lines.copy()
fixed_lines = []
changes_made = False
in_code_block = False
for _, line in enumerate(lines):
line = line.rstrip()
if not line.strip():
fixed_lines.append("\n")
continue
if line.startswith("```"):
in_code_block = not in_code_block
fixed_lines.append(line + "\n")
continue
if in_code_block:
fixed_lines.append(line + "\n")
continue
if line.strip().startswith(("- ", "* ", "+ ", "1. ")):
if len(line) > 120:
broken_line = self._break_list_item(line)
if broken_line != line:
changes_made = True
print(f" breaking long list item in {file_path}")
fixed_lines.append(broken_line + "\n")
else:
fixed_lines.append(line + "\n")
continue
if len(line) > 120:
broken_lines = self.break_long_line(line)
if broken_lines != line:
changes_made = True
print(
f" breaking long line in {file_path}: {len(line)} chars -> {len(broken_lines.split(chr(10))[0])} chars"
)
for broken_line in broken_lines.split("\n"):
if broken_line.strip():
fixed_lines.append(broken_line + "\n")
else:
fixed_lines.append("\n")
else:
fixed_lines.append(line + "\n")
if changes_made:
with open(file_path, "w", encoding="utf-8") as f:
f.writelines(fixed_lines)
print(f" ✅ fixed {file_path}")
return True
else:
print(f" no changes needed in {file_path}")
except Exception as e:
print(f"❌ error fixing {file_path}: {e}")
self.errors_encountered += 1
def break_long_line(self, line: str) -> str:
if len(line) <= 120:
return line
if ". " in line:
parts = line.split(". ")
if len(parts[0]) <= 120:
remaining = ". ".join(parts[1:])
if len(remaining) <= 120:
return parts[0] + ". " + remaining
else:
broken_remaining = self._break_at_words(remaining)
return parts[0] + ".\n" + broken_remaining
return self._break_at_words(line)
def _break_at_words(self, line: str) -> str:
words = line.split()
result = []
current_line = ""
for word in words:
if current_line and len(current_line + " " + word) > 120:
if current_line:
result.append(current_line)
current_line = word
else:
if current_line:
current_line += " " + word
else:
current_line = word
if current_line:
result.append(current_line)
return "\n".join(result)
def _break_list_item(self, line: str) -> str:
marker_end = 0
for i, char in enumerate(line):
if char in "-*+" or (char.isdigit() and line[i + 1 : i + 3] == ". "):
marker_end = line.find(" ", i)
if marker_end == -1:
marker_end = len(line)
break
if marker_end == 0:
return self.break_long_line(line)
marker = line[: marker_end + 1]
content = line[marker_end + 1 :]
if len(content) <= 120 - len(marker):
return line
broken_content = self._break_at_words(content)
if "\n" in broken_content:
indent = " " * len(marker)
lines = broken_content.split("\n")
result = [marker + lines[0]]
for continuation_line in lines[1:]:
result.append(indent + continuation_line)
return "\n".join(result)
return line
def fix_trailing_whitespace(self) -> bool:
print("\n🧹 fixing trailing whitespace...")
text_extensions = {".py", ".md", ".txt", ".rst", ".yml", ".yaml", ".json"}
files_fixed = 0
for root, dirs, files in os.walk("."):
dirs[:] = [
d
for d in dirs
if not d.startswith(".")
and d not in ["venv", "__pycache__", ".venv", "node_modules"]
]
for file in files:
if any(file.endswith(ext) for ext in text_extensions):
file_path = os.path.join(root, file)
if self.fix_file_trailing_whitespace(file_path):
files_fixed += 1
if files_fixed > 0:
self.fixes_applied += 1
print(f" fixed trailing whitespace in {files_fixed} files")
return True
def fix_file_trailing_whitespace(self, file_path: str) -> bool:
try:
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
original_lines = lines.copy()
fixed_lines = []
for line in lines:
fixed_line = line.rstrip() + "\n"
fixed_lines.append(fixed_line)
if fixed_lines != original_lines:
with open(file_path, "w", encoding="utf-8") as f:
f.writelines(fixed_lines)
print(f" fixed trailing whitespace in {file_path}")
return True
except Exception as e:
print(f"warning: could not fix {file_path}: {e}")
def run_all_fixes(self) -> bool:
print("🚀 starting auto-fix process...")
success = True
success &= self.fix_trailing_whitespace()
success &= self.fix_python_code()
success &= self.fix_markdown_files()
return success
def print_summary(self):
print("\n" + "=" * 50)
print("🎯 AUTO-FIX SUMMARY")
print("=" * 50)
print(f"✅ fixes applied: {self.fixes_applied}")
print(f"❌ errors encountered: {self.errors_encountered}")
if self.link_report["total_links"] > 0:
print("\n🔗 LINK CHECK SUMMARY")
print("-" * 30)
print(f"📊 total links found: {self.link_report['total_links']}")
print(f"🔍 internal links: {self.link_report['internal_links']}")
print(f"🌐 external links: {self.link_report['external_links']}")
print(f"❌ broken internal: {self.link_report['broken_internal']}")
print(f"❌ broken external: {self.link_report['broken_external']}")
if self.link_report["broken_links"]:
print(f"\n⚠️ broken links found:")
for link in self.link_report["broken_links"]:
print(
f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
)
if (
self.errors_encountered == 0
and self.link_report["broken_internal"] == 0
and self.link_report["broken_external"] == 0
):
print("\n🎉 all fixes completed successfully and all links are working!")
elif self.errors_encountered == 0:
print(f"\n✅ all fixes completed successfully!")
print(
f"⚠️ but {self.link_report['broken_internal'] + self.link_report['broken_external']} broken links were found"
)
else:
print(f"\n⚠️ {self.errors_encountered} errors occurred during fixing")
def main():
fixer = AutoFixer()
fixer.run_all_fixes()
fixer.print_summary()
if __name__ == "__main__":
main()

10
scripts/requirements.txt Normal file
View file

@ -0,0 +1,10 @@
textstat>=0.7.3
requests>=2.28.0
beautifulsoup4>=4.11.0
markdown>=3.4.0
black>=23.0.0
isort>=5.12.0
flake8>=6.0.0
mypy>=1.0.0
autopep8>=2.0.0
autoflake>=2.0.0