mirror of
https://github.com/autistic-symposium/ml-ai-agents-py.git
synced 2025-08-15 17:50:12 -04:00
chores: refactor for the new ai research, add linter, gh action, etc (#27)
This commit is contained in:
parent
fb4ab80dc3
commit
d5467e559f
40 changed files with 5177 additions and 2476 deletions
0
.github/.keep
vendored
Normal file
0
.github/.keep
vendored
Normal file
68
.github/workflows/auto-fix.yml
vendored
Normal file
68
.github/workflows/auto-fix.yml
vendored
Normal file
|
@ -0,0 +1,68 @@
|
|||
name: 👾 auto-fix code quality
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
workflow_dispatch: # allow manual triggering
|
||||
|
||||
jobs:
|
||||
auto-fix:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # full history for better diff detection
|
||||
|
||||
- name: set up python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.9'
|
||||
cache: 'pip'
|
||||
|
||||
- name: install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r scripts/requirements.txt
|
||||
|
||||
- name: install code quality tools
|
||||
run: |
|
||||
pip install black isort autopep8 autoflake
|
||||
|
||||
- name: run auto-fix script
|
||||
run: |
|
||||
python scripts/auto_fix.py
|
||||
|
||||
- name: check for changes
|
||||
id: check_changes
|
||||
run: |
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
echo "changes=true" >> $GITHUB_OUTPUT
|
||||
echo "files were modified by auto-fix script"
|
||||
git status --porcelain
|
||||
else
|
||||
echo "changes=false" >> $GITHUB_OUTPUT
|
||||
echo "no files were modified"
|
||||
fi
|
||||
|
||||
- name: commit and push changes (if any)
|
||||
if: steps.check_changes.outputs.changes == 'true'
|
||||
run: |
|
||||
git config --local user.email "action@github.com"
|
||||
git config --local user.name "github action"
|
||||
git add -a
|
||||
git commit -m "🔧 auto-fix code quality issues
|
||||
|
||||
- applied black formatting
|
||||
- organized imports with isort
|
||||
- fixed code style with autopep8
|
||||
- removed unused imports with autoflake
|
||||
- fixed markdown formatting
|
||||
- validated and fixed links
|
||||
- removed trailing whitespace
|
||||
|
||||
auto-generated by github actions"
|
||||
git push
|
204
.gitignore
vendored
Normal file
204
.gitignore
vendored
Normal file
|
@ -0,0 +1,204 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Django stuff (keeping in case of web components)
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff (keeping in case of web components)
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# poetry
|
||||
poetry.lock
|
||||
|
||||
# pdm
|
||||
pdm.lock
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
*.iml
|
||||
*.ipr
|
||||
*.iws
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
Icon
|
||||
._*
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# Windows
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
*.stackdump
|
||||
[Dd]esktop.ini
|
||||
$RECYCLE.BIN/
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
*.lnk
|
||||
|
||||
# Linux
|
||||
*~
|
||||
.fuse_hidden*
|
||||
.directory
|
||||
.Trash-*
|
||||
.nfs*
|
||||
|
||||
# Machine Learning specific
|
||||
*.pkl
|
||||
*.pickle
|
||||
*.joblib
|
||||
*.h5
|
||||
*.hdf5
|
||||
*.model
|
||||
*.weights
|
||||
*.ckpt
|
||||
*.pth
|
||||
*.pt
|
||||
*.onnx
|
||||
*.tflite
|
||||
*.pb
|
||||
|
||||
# Data files
|
||||
*.csv
|
||||
*.json
|
||||
*.parquet
|
||||
*.feather
|
||||
*.hdf
|
||||
*.xlsx
|
||||
*.xls
|
||||
|
||||
# Large model files
|
||||
models/
|
||||
checkpoints/
|
||||
runs/
|
||||
logs/
|
||||
wandb/
|
||||
|
||||
# Environment variables
|
||||
.env.local
|
||||
.env.development
|
||||
.env.test
|
||||
.env.production
|
||||
|
||||
# IDE specific
|
||||
*.swp
|
||||
*.swo
|
106
EBMs/README.md
106
EBMs/README.md
|
@ -1,12 +1,15 @@
|
|||
## quantum ai: training energy-based-models using openai
|
||||
## quantum ai: training energy-based-models using openAI
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
#### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
|
||||
|
||||
#### ⚛️ this repository contains my adapted code from [opeani's implicit generation and generalization in
|
||||
energy-based-models](https://arxiv.org/pdf/1903.08689.pdf)
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
### installing
|
||||
|
||||
<br>
|
||||
|
@ -19,7 +22,8 @@ brew install pkg-config
|
|||
|
||||
<br>
|
||||
|
||||
* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this problem (`PMIX ERROR: ERROR`) that can be fixed with:
|
||||
* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi for the specific libraries in this
|
||||
problem (`PMIX ERROR: ERROR`) that can be fixed with:
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -40,7 +44,8 @@ pip install -r requirements.txt
|
|||
```
|
||||
<br>
|
||||
|
||||
* note that this is an adapted requirement file since the **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
|
||||
* note that this is an adapted requirement file since the **[openai's
|
||||
original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is not complete/correct
|
||||
* finally, download and install **[mujoco](https://www.roboti.us/index.html)**
|
||||
* you will also need to register for a license, which asks for a machine ID
|
||||
* the documentation on the website is incomplete, so just download the suggested script and run:
|
||||
|
@ -64,7 +69,8 @@ mv getid_osx getid_osx.dms
|
|||
|
||||
<br>
|
||||
|
||||
* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`:
|
||||
* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder
|
||||
`cachedir`:
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -78,7 +84,8 @@ mkdir cachedir
|
|||
|
||||
<br>
|
||||
|
||||
* openai's original code contains **[hardcoded constants that only work on Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
|
||||
* openai's original code contains **[hardcoded constants that only work on
|
||||
Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
|
||||
* i changed this to a constant (`ROOT_DIR = "./results"`) in the top of `data.py`
|
||||
|
||||
<br>
|
||||
|
@ -87,7 +94,8 @@ mkdir cachedir
|
|||
|
||||
<br>
|
||||
|
||||
* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased substantially by using multiple different workers by running each command:
|
||||
* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so model training can be increased
|
||||
substantially by using multiple different workers by running each command:
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -102,7 +110,8 @@ mpiexec -n <worker_num> <command>
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
|
||||
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
|
||||
--zero_kl --replay_batch --large_model
|
||||
```
|
||||
|
||||
* this should generate the following output:
|
||||
|
@ -112,7 +121,8 @@ python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_si
|
|||
```bash
|
||||
Instructions for updating:
|
||||
Use tf.gfile.GFile.
|
||||
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
|
||||
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is
|
||||
deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
|
||||
64 batch size
|
||||
Local rank: 0 1
|
||||
Loading data...
|
||||
|
@ -121,11 +131,15 @@ Files already downloaded and verified
|
|||
Files already downloaded and verified
|
||||
Files already downloaded and verified
|
||||
Done loading...
|
||||
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
|
||||
WARNING:tensorflow:From
|
||||
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263:
|
||||
colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
|
||||
Instructions for updating:
|
||||
Colocations handled automatically by placer.
|
||||
Building graph...
|
||||
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
|
||||
WARNING:tensorflow:From
|
||||
/Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from
|
||||
tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
|
||||
Instructions for updating:
|
||||
Use tf.cast instead.
|
||||
Finished processing loop construction ...
|
||||
|
@ -136,16 +150,36 @@ Model has a total of 7567880 parameters
|
|||
Initializing variables...
|
||||
Start broadcast
|
||||
End broadcast
|
||||
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996, l
|
||||
oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818, context_0/res_optim_res_c1/
|
||||
cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10, context_0/res_optim_res_c2/gb:0: 6.27235652306268
|
||||
3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0
|
||||
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2
|
||||
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.0194275379180908
|
||||
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.75612463
|
||||
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0: 0.34806951880455017, context_0/res_4_res_c2/g:
|
||||
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5
|
||||
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/
|
||||
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff:
|
||||
0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent:
|
||||
2.272536277770996, l
|
||||
oss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first:
|
||||
0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818,
|
||||
context_0/res_optim_res_c1/
|
||||
cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10,
|
||||
context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10,
|
||||
context_0/res_optim_res_c2/gb:0: 6.27235652306268
|
||||
3e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11,
|
||||
context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437,
|
||||
context_0/res_1_res_c2/g:0
|
||||
: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0:
|
||||
1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0:
|
||||
6.868505764145993e-10, context_0/res_2
|
||||
_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0:
|
||||
8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0:
|
||||
0.0194275379180908
|
||||
2, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10,
|
||||
context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11,
|
||||
context_0/res_3_res_c2/gb:0: 7.75612463
|
||||
1441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0:
|
||||
4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0:
|
||||
0.34806951880455017, context_0/res_4_res_c2/g:
|
||||
0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0:
|
||||
5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0:
|
||||
3.807749116013781e-10, context_0/res_5
|
||||
_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0:
|
||||
7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039,
|
||||
context_0/fc5/
|
||||
b:0: 0.4506262540817261,
|
||||
|
||||
................................................................................................................................
|
||||
|
@ -159,7 +193,8 @@ Inception score of 1.2397289276123047 with std of 0.0
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
|
||||
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01
|
||||
--zero_kl --replay_batch --cclass
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -169,7 +204,8 @@ python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
|
||||
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 step_lr=10.0 --proj_norm=0.01
|
||||
--replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=<imagenet32x32 path>
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -179,7 +215,8 @@ python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=3
|
|||
<br>
|
||||
|
||||
```
|
||||
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
|
||||
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0 --replay_batch --swish_act --cclass
|
||||
--zero_kl --dataset=imagenetfull --imagenet_datadir=<full imagenet path>
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -197,7 +234,8 @@ python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 step_lr=100.0
|
|||
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
|
||||
```
|
||||
|
||||
* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by different settings of task flag in the file
|
||||
* the ebm_sandbox.py file contains several different tasks that can be used to evaluate ebms, which are defined by
|
||||
different settings of task flag in the file
|
||||
* for example, to visualize cross class mappings in cifar-10, you can run:
|
||||
|
||||
<br>
|
||||
|
@ -217,7 +255,8 @@ python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resu
|
|||
<br>
|
||||
|
||||
```
|
||||
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
|
||||
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200
|
||||
--large_model --svhnmix --cclass=False
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -227,7 +266,8 @@ python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_
|
|||
<br>
|
||||
|
||||
```
|
||||
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd steps> --num_steps=10 --lival=<li bound value> --wider_model
|
||||
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd=<number of pgd
|
||||
steps> --num_steps=10 --lival=<li bound value> --wider_model
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -236,12 +276,14 @@ python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=
|
|||
|
||||
<br>
|
||||
|
||||
* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in `cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
|
||||
* to train ebms on conditional dsprites dataset, you can train each model separately on each conditioned latent in
|
||||
`cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`, with an example command given below:
|
||||
|
||||
<br>
|
||||
|
||||
```
|
||||
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch -cclass
|
||||
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act
|
||||
--cond_pos --replay_batch -cclass
|
||||
```
|
||||
|
||||
<br>
|
||||
|
@ -249,5 +291,7 @@ python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps
|
|||
* once models are trained, they can be sampled from jointly by running:
|
||||
|
||||
```
|
||||
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos> --exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot> --resume_pos=<resume_pos>
|
||||
python ebm_combine.py --task=conceptcombine --exp_size=<exp_size> --exp_shape=<exp_shape> --exp_pos=<exp_pos>
|
||||
--exp_rot=<exp_rot> --resume_size=<resume_size> --resume_shape=<resume_shape> --resume_rot=<resume_rot>
|
||||
--resume_pos=<resume_pos>
|
||||
```
|
||||
|
|
228
EBMs/ais.py
228
EBMs/ais.py
|
@ -1,40 +1,65 @@
|
|||
import tensorflow as tf
|
||||
import math
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from data import Cifar10, DSprites, Mnist
|
||||
from hmc import hmc
|
||||
from models import DspritesNet, MnistNet, ResNet32, ResNet32Large, ResNet32Wider
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader
|
||||
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
|
||||
from data import Cifar10, Mnist, DSprites
|
||||
from scipy.misc import logsumexp
|
||||
from scipy.misc import imsave
|
||||
from utils import optimistic_restore
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from utils import optimistic_restore
|
||||
|
||||
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
|
||||
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
|
||||
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
|
||||
flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
|
||||
flags.DEFINE_string(
|
||||
"dataset", "cifar10", "cifar10 or mnist or dsprites or 2d or toy Gauss"
|
||||
)
|
||||
flags.DEFINE_string(
|
||||
"logdir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_string("exp", "default", "name of experiments")
|
||||
flags.DEFINE_integer(
|
||||
"data_workers", 5, "Number of different data workers to load data in parallel"
|
||||
)
|
||||
flags.DEFINE_integer("batch_size", 16, "Size of inputs")
|
||||
flags.DEFINE_string("resume_iter", "-1", "iteration to resume training from")
|
||||
|
||||
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
|
||||
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
|
||||
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
|
||||
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
|
||||
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
|
||||
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
|
||||
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
|
||||
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
|
||||
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
|
||||
flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
|
||||
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
|
||||
flags.DEFINE_bool(
|
||||
"max_pool",
|
||||
False,
|
||||
"Whether or not to use max pooling rather than strided convolutions",
|
||||
)
|
||||
flags.DEFINE_integer(
|
||||
"num_filters",
|
||||
64,
|
||||
"number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
|
||||
)
|
||||
flags.DEFINE_integer("pdist", 10, "number of intermediate distributions for ais")
|
||||
flags.DEFINE_integer("gauss_dim", 500, "dimensions for modeling Gaussian")
|
||||
flags.DEFINE_integer(
|
||||
"rescale", 1, "factor to rescale input outside of normal (0, 1) box"
|
||||
)
|
||||
flags.DEFINE_float(
|
||||
"temperature", 1, "temperature at which to compute likelihood of model"
|
||||
)
|
||||
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
|
||||
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
|
||||
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
|
||||
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
|
||||
flags.DEFINE_bool(
|
||||
"cclass",
|
||||
False,
|
||||
"Whether to evaluate the log likelihood of conditional model or not",
|
||||
)
|
||||
flags.DEFINE_bool(
|
||||
"single",
|
||||
False,
|
||||
"Whether to evaluate the log likelihood of conditional model or not",
|
||||
)
|
||||
flags.DEFINE_bool("large_model", False, "Use large model to evaluate")
|
||||
flags.DEFINE_bool("wider_model", False, "Use large model to evaluate")
|
||||
flags.DEFINE_float("alr", 0.0045, "Learning rate to use for HMC steps")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
@ -45,11 +70,12 @@ label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
|
|||
def unscale_im(im):
|
||||
return (255 * np.clip(im, 0, 1)).astype(np.uint8)
|
||||
|
||||
|
||||
def gauss_prob_log(x, prec=1.0):
|
||||
|
||||
nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
|
||||
norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
|
||||
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
|
||||
prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2.0 * prec
|
||||
|
||||
return norm_constant_log + prob_density_log
|
||||
|
||||
|
@ -73,23 +99,36 @@ def model_prob_log(x, e_func, weights, temp):
|
|||
def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
|
||||
|
||||
if FLAGS.dataset == "gauss":
|
||||
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
|
||||
norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(
|
||||
x, prec=FLAGS.temperature
|
||||
)
|
||||
else:
|
||||
norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
|
||||
# Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
|
||||
norm_prob = (1 - alpha) * uniform_prob_log(x) + alpha * model_prob_log(
|
||||
x, e_func, weights, temp
|
||||
)
|
||||
# Add an additional log likelihood penalty so that points outside of (0,
|
||||
# 1) box are *highly* unlikely
|
||||
|
||||
|
||||
if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
|
||||
elif FLAGS.dataset == 'mnist':
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
|
||||
if FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
|
||||
oob_prob = tf.reduce_sum(
|
||||
tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1]
|
||||
)
|
||||
elif FLAGS.dataset == "mnist":
|
||||
oob_prob = tf.reduce_sum(
|
||||
tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis=[1, 2]
|
||||
)
|
||||
else:
|
||||
oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
|
||||
oob_prob = tf.reduce_sum(
|
||||
tf.square(100 * (x - tf.clip_by_value(x, 0.0, FLAGS.rescale))),
|
||||
axis=[1, 2, 3],
|
||||
)
|
||||
|
||||
return -norm_prob + oob_prob
|
||||
|
||||
|
||||
def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
|
||||
def ancestral_sample(
|
||||
e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10
|
||||
):
|
||||
if FLAGS.dataset == "2d":
|
||||
x = tf.placeholder(tf.float32, shape=(None, 2))
|
||||
elif FLAGS.dataset == "gauss":
|
||||
|
@ -130,41 +169,46 @@ def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_
|
|||
def main():
|
||||
|
||||
# Initialize dataset
|
||||
if FLAGS.dataset == 'cifar10':
|
||||
if FLAGS.dataset == "cifar10":
|
||||
dataset = Cifar10(train=False, rescale=FLAGS.rescale)
|
||||
channel_num = 3
|
||||
dim_input = 32 * 32 * 3
|
||||
elif FLAGS.dataset == 'imagenet':
|
||||
32 * 32 * 3
|
||||
elif FLAGS.dataset == "imagenet":
|
||||
dataset = ImagenetClass()
|
||||
channel_num = 3
|
||||
dim_input = 64 * 64 * 3
|
||||
elif FLAGS.dataset == 'mnist':
|
||||
64 * 64 * 3
|
||||
elif FLAGS.dataset == "mnist":
|
||||
dataset = Mnist(train=False, rescale=FLAGS.rescale)
|
||||
channel_num = 1
|
||||
dim_input = 28 * 28 * 1
|
||||
elif FLAGS.dataset == 'dsprites':
|
||||
28 * 28 * 1
|
||||
elif FLAGS.dataset == "dsprites":
|
||||
dataset = DSprites()
|
||||
channel_num = 1
|
||||
dim_input = 64 * 64 * 1
|
||||
elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
|
||||
64 * 64 * 1
|
||||
elif FLAGS.dataset == "2d" or FLAGS.dataset == "gauss":
|
||||
dataset = Box2D()
|
||||
|
||||
dim_output = 1
|
||||
data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=FLAGS.data_workers,
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
if FLAGS.dataset == 'mnist':
|
||||
if FLAGS.dataset == "mnist":
|
||||
model = MnistNet(num_channels=channel_num)
|
||||
elif FLAGS.dataset == 'cifar10':
|
||||
elif FLAGS.dataset == "cifar10":
|
||||
if FLAGS.large_model:
|
||||
model = ResNet32Large(num_filters=128)
|
||||
elif FLAGS.wider_model:
|
||||
model = ResNet32Wider(num_filters=192)
|
||||
else:
|
||||
model = ResNet32(num_channels=channel_num, num_filters=128)
|
||||
elif FLAGS.dataset == 'dsprites':
|
||||
elif FLAGS.dataset == "dsprites":
|
||||
model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
|
||||
|
||||
weights = model.construct_weights('context_{}'.format(0))
|
||||
weights = model.construct_weights("context_{}".format(0))
|
||||
|
||||
config = tf.ConfigProto()
|
||||
sess = tf.Session(config=config)
|
||||
|
@ -173,8 +217,8 @@ def main():
|
|||
sess.run(tf.global_variables_initializer())
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
|
||||
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
resume_itr = FLAGS.resume_iter
|
||||
model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
|
||||
FLAGS.resume_iter
|
||||
|
||||
if FLAGS.resume_iter != "-1":
|
||||
optimistic_restore(sess, model_file)
|
||||
|
@ -182,14 +226,17 @@ def main():
|
|||
print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
|
||||
# saver.restore(sess, model_file)
|
||||
|
||||
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
|
||||
chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
|
||||
model, weights, FLAGS.batch_size, temp=FLAGS.temperature
|
||||
)
|
||||
print("Finished constructing ancestral sample ...................")
|
||||
|
||||
if FLAGS.dataset != "gauss":
|
||||
comb_weights_cum = []
|
||||
batch_size = tf.shape(x_init)[0]
|
||||
label_tiled = tf.tile(label_default, (batch_size, 1))
|
||||
e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
|
||||
e_compute = -FLAGS.temperature * model.forward(
|
||||
x_init, weights, label=label_tiled
|
||||
)
|
||||
e_pos_list = []
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(data_loader):
|
||||
|
@ -205,44 +252,75 @@ def main():
|
|||
alr = 0.0085
|
||||
elif FLAGS.dataset == "mnist":
|
||||
alr = 0.0065
|
||||
#90 alr = 0.0035
|
||||
# 90 alr = 0.0035
|
||||
else:
|
||||
# alr = 0.0125
|
||||
if FLAGS.rescale == 8:
|
||||
alr = 0.0085
|
||||
else:
|
||||
alr = 0.0045
|
||||
#
|
||||
#
|
||||
for i in range(1):
|
||||
tot_weight = 0
|
||||
for j in tqdm(range(1, FLAGS.pdist+1)):
|
||||
for j in tqdm(range(1, FLAGS.pdist + 1)):
|
||||
if j == 1:
|
||||
if FLAGS.dataset == "cifar10":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3)
|
||||
)
|
||||
elif FLAGS.dataset == "gauss":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim)
|
||||
)
|
||||
elif FLAGS.dataset == "mnist":
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28)
|
||||
)
|
||||
else:
|
||||
x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
|
||||
x_curr = np.random.uniform(
|
||||
0, FLAGS.rescale, size=(FLAGS.batch_size, 2)
|
||||
)
|
||||
|
||||
alpha_prev = (j-1) / FLAGS.pdist
|
||||
alpha_prev = (j - 1) / FLAGS.pdist
|
||||
alpha_new = j / FLAGS.pdist
|
||||
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
|
||||
cweight, x_curr = sess.run(
|
||||
[chain_weights, x],
|
||||
{
|
||||
a_prev: alpha_prev,
|
||||
a_new: alpha_new,
|
||||
x_init: x_curr,
|
||||
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
|
||||
},
|
||||
)
|
||||
tot_weight = tot_weight + cweight
|
||||
|
||||
print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
|
||||
print(
|
||||
"Total values of lower value based off forward sampling",
|
||||
np.mean(tot_weight),
|
||||
np.std(tot_weight),
|
||||
)
|
||||
|
||||
tot_weight = 0
|
||||
|
||||
for j in tqdm(range(FLAGS.pdist, 0, -1)):
|
||||
alpha_new = (j-1) / FLAGS.pdist
|
||||
alpha_new = (j - 1) / FLAGS.pdist
|
||||
alpha_prev = j / FLAGS.pdist
|
||||
cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
|
||||
cweight, x_curr = sess.run(
|
||||
[chain_weights, x],
|
||||
{
|
||||
a_prev: alpha_prev,
|
||||
a_new: alpha_new,
|
||||
x_init: x_curr,
|
||||
approx_lr: alr * (5 ** (2.5 * -alpha_prev)),
|
||||
},
|
||||
)
|
||||
tot_weight = tot_weight - cweight
|
||||
|
||||
print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
|
||||
|
||||
print(
|
||||
"Total values of upper value based off backward sampling",
|
||||
np.mean(tot_weight),
|
||||
np.std(tot_weight),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -14,223 +14,244 @@
|
|||
# ==============================================================================
|
||||
|
||||
"""Adam for TensorFlow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import training_ops
|
||||
from tensorflow.python.ops import (
|
||||
control_flow_ops,
|
||||
math_ops,
|
||||
resource_variable_ops,
|
||||
state_ops,
|
||||
)
|
||||
from tensorflow.python.training import optimizer, training_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@tf_export("train.AdamOptimizer")
|
||||
class AdamOptimizer(optimizer.Optimizer):
|
||||
"""Optimizer that implements the Adam algorithm.
|
||||
"""Optimizer that implements the Adam algorithm.
|
||||
|
||||
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
||||
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
|
||||
use_locking=False, name="Adam"):
|
||||
"""Construct a new Adam optimizer.
|
||||
|
||||
Initialization:
|
||||
|
||||
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
|
||||
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
|
||||
$$t := 0 \text{(Initialize timestep)}$$
|
||||
|
||||
The update rule for `variable` with gradient `g` uses an optimization
|
||||
described at the end of section2 of the paper:
|
||||
|
||||
$$t := t + 1$$
|
||||
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
|
||||
|
||||
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
|
||||
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
|
||||
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||
|
||||
The default value of 1e-8 for epsilon might not be a good default in
|
||||
general. For example, when training an Inception network on ImageNet a
|
||||
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||
hat" in the paper.
|
||||
|
||||
The sparse implementation of this algorithm (used when the gradient is an
|
||||
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||
lookup in the forward pass) does apply momentum to variable slices even if
|
||||
they were not used in the forward pass (meaning they have a gradient equal
|
||||
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||
behavior (in contrast to some momentum implementations which ignore momentum
|
||||
unless a variable slice was actually used).
|
||||
|
||||
Args:
|
||||
learning_rate: A Tensor or a floating point value. The learning rate.
|
||||
beta1: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 1st moment estimates.
|
||||
beta2: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 2nd moment estimates.
|
||||
epsilon: A small constant for numerical stability. This epsilon is
|
||||
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||
Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
||||
use_locking: If True use locks for update operations.
|
||||
name: Optional name for the operations created when applying gradients.
|
||||
Defaults to "Adam".
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
|
||||
`epsilon` can each be a callable that takes no arguments and returns the
|
||||
actual value to use. This can be useful for changing these values across
|
||||
different invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
|
||||
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
|
||||
"""
|
||||
super(AdamOptimizer, self).__init__(use_locking, name)
|
||||
self._lr = learning_rate
|
||||
self._beta1 = beta1
|
||||
self._beta2 = beta2
|
||||
self._epsilon = epsilon
|
||||
|
||||
# Tensor versions of the constructor arguments, created in _prepare().
|
||||
self._lr_t = None
|
||||
self._beta1_t = None
|
||||
self._beta2_t = None
|
||||
self._epsilon_t = None
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8,
|
||||
use_locking=False,
|
||||
name="Adam",
|
||||
):
|
||||
"""Construct a new Adam optimizer.
|
||||
|
||||
# Created in SparseApply if needed.
|
||||
self._updated_lr = None
|
||||
Initialization:
|
||||
|
||||
def _get_beta_accumulators(self):
|
||||
with ops.init_scope():
|
||||
if context.executing_eagerly():
|
||||
graph = None
|
||||
else:
|
||||
graph = ops.get_default_graph()
|
||||
return (self._get_non_slot_variable("beta1_power", graph=graph),
|
||||
self._get_non_slot_variable("beta2_power", graph=graph))
|
||||
$$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
|
||||
$$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
|
||||
$$t := 0 \text{(Initialize timestep)}$$
|
||||
|
||||
def _create_slots(self, var_list):
|
||||
# Create the beta1 and beta2 accumulators on the same device as the first
|
||||
# variable. Sort the var_list to make sure this device is consistent across
|
||||
# workers (these need to go on the same PS, otherwise some updates are
|
||||
# silently ignored).
|
||||
first_var = min(var_list, key=lambda x: x.name)
|
||||
self._create_non_slot_variable(initial_value=self._beta1,
|
||||
name="beta1_power",
|
||||
colocate_with=first_var)
|
||||
self._create_non_slot_variable(initial_value=self._beta2,
|
||||
name="beta2_power",
|
||||
colocate_with=first_var)
|
||||
The update rule for `variable` with gradient `g` uses an optimization
|
||||
described at the end of section2 of the paper:
|
||||
|
||||
# Create slots for the first and second moments.
|
||||
for v in var_list:
|
||||
self._zeros_slot(v, "m", self._name)
|
||||
self._zeros_slot(v, "v", self._name)
|
||||
$$t := t + 1$$
|
||||
$$lr_t := \text{learning\\_rate} * \\sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
|
||||
|
||||
def _prepare(self):
|
||||
lr = self._call_if_callable(self._lr)
|
||||
beta1 = self._call_if_callable(self._beta1)
|
||||
beta2 = self._call_if_callable(self._beta2)
|
||||
epsilon = self._call_if_callable(self._epsilon)
|
||||
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
|
||||
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
|
||||
$$variable := variable - lr_t * m_t / (\\sqrt{v_t} + \\epsilon)$$
|
||||
|
||||
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
|
||||
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
|
||||
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
|
||||
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
|
||||
The default value of 1e-8 for epsilon might not be a good default in
|
||||
general. For example, when training an Inception network on ImageNet a
|
||||
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
|
||||
formulation just before Section 2.1 of the Kingma and Ba paper rather than
|
||||
the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
|
||||
hat" in the paper.
|
||||
|
||||
def _apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
The sparse implementation of this algorithm (used when the gradient is an
|
||||
IndexedSlices object, typically because of `tf.gather` or an embedding
|
||||
lookup in the forward pass) does apply momentum to variable slices even if
|
||||
they were not used in the forward pass (meaning they have a gradient equal
|
||||
to zero). Momentum decay (beta1) is also applied to the entire momentum
|
||||
accumulator. This means that the sparse behavior is equivalent to the dense
|
||||
behavior (in contrast to some momentum implementations which ignore momentum
|
||||
unless a variable slice was actually used).
|
||||
|
||||
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
|
||||
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
|
||||
# Clip gradients by 3 std
|
||||
return training_ops.apply_adam(
|
||||
var, m, v,
|
||||
math_ops.cast(beta1_power, var.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, var.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
|
||||
grad, use_locking=self._use_locking).op
|
||||
Args:
|
||||
learning_rate: A Tensor or a floating point value. The learning rate.
|
||||
beta1: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 1st moment estimates.
|
||||
beta2: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 2nd moment estimates.
|
||||
epsilon: A small constant for numerical stability. This epsilon is
|
||||
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||
Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
||||
use_locking: If True use locks for update operations.
|
||||
name: Optional name for the operations created when applying gradients.
|
||||
Defaults to "Adam".
|
||||
|
||||
def _resource_apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
return training_ops.resource_apply_adam(
|
||||
var.handle, m.handle, v.handle,
|
||||
math_ops.cast(beta1_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
|
||||
grad, use_locking=self._use_locking)
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
|
||||
`epsilon` can each be a callable that takes no arguments and returns the
|
||||
actual value to use. This can be useful for changing these values across
|
||||
different invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
super(AdamOptimizer, self).__init__(use_locking, name)
|
||||
self._lr = learning_rate
|
||||
self._beta1 = beta1
|
||||
self._beta2 = beta2
|
||||
self._epsilon = epsilon
|
||||
|
||||
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
|
||||
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
|
||||
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
||||
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
||||
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
||||
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
||||
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
||||
# m_t = beta1 * m + (1 - beta1) * g_t
|
||||
m = self.get_slot(var, "m")
|
||||
m_scaled_g_values = grad * (1 - beta1_t)
|
||||
m_t = state_ops.assign(m, m * beta1_t,
|
||||
use_locking=self._use_locking)
|
||||
with ops.control_dependencies([m_t]):
|
||||
m_t = scatter_add(m, indices, m_scaled_g_values)
|
||||
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
||||
v = self.get_slot(var, "v")
|
||||
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
||||
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
|
||||
with ops.control_dependencies([v_t]):
|
||||
v_t = scatter_add(v, indices, v_scaled_g_values)
|
||||
v_sqrt = math_ops.sqrt(v_t)
|
||||
var_update = state_ops.assign_sub(var,
|
||||
lr * m_t / (v_sqrt + epsilon_t),
|
||||
use_locking=self._use_locking)
|
||||
return control_flow_ops.group(*[var_update, m_t, v_t])
|
||||
# Tensor versions of the constructor arguments, created in _prepare().
|
||||
self._lr_t = None
|
||||
self._beta1_t = None
|
||||
self._beta2_t = None
|
||||
self._epsilon_t = None
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
return self._apply_sparse_shared(
|
||||
grad.values, var, grad.indices,
|
||||
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
|
||||
x, i, v, use_locking=self._use_locking))
|
||||
# Created in SparseApply if needed.
|
||||
self._updated_lr = None
|
||||
|
||||
def _resource_scatter_add(self, x, i, v):
|
||||
with ops.control_dependencies(
|
||||
[resource_variable_ops.resource_scatter_add(
|
||||
x.handle, i, v)]):
|
||||
return x.value()
|
||||
def _get_beta_accumulators(self):
|
||||
with ops.init_scope():
|
||||
if context.executing_eagerly():
|
||||
graph = None
|
||||
else:
|
||||
graph = ops.get_default_graph()
|
||||
return (
|
||||
self._get_non_slot_variable("beta1_power", graph=graph),
|
||||
self._get_non_slot_variable("beta2_power", graph=graph),
|
||||
)
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
return self._apply_sparse_shared(
|
||||
grad, var, indices, self._resource_scatter_add)
|
||||
def _create_slots(self, var_list):
|
||||
# Create the beta1 and beta2 accumulators on the same device as the first
|
||||
# variable. Sort the var_list to make sure this device is consistent across
|
||||
# workers (these need to go on the same PS, otherwise some updates are
|
||||
# silently ignored).
|
||||
first_var = min(var_list, key=lambda x: x.name)
|
||||
self._create_non_slot_variable(
|
||||
initial_value=self._beta1, name="beta1_power", colocate_with=first_var
|
||||
)
|
||||
self._create_non_slot_variable(
|
||||
initial_value=self._beta2, name="beta2_power", colocate_with=first_var
|
||||
)
|
||||
|
||||
def _finish(self, update_ops, name_scope):
|
||||
# Update the power accumulators.
|
||||
with ops.control_dependencies(update_ops):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
with ops.colocate_with(beta1_power):
|
||||
update_beta1 = beta1_power.assign(
|
||||
beta1_power * self._beta1_t, use_locking=self._use_locking)
|
||||
update_beta2 = beta2_power.assign(
|
||||
beta2_power * self._beta2_t, use_locking=self._use_locking)
|
||||
return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
|
||||
name=name_scope)
|
||||
# Create slots for the first and second moments.
|
||||
for v in var_list:
|
||||
self._zeros_slot(v, "m", self._name)
|
||||
self._zeros_slot(v, "v", self._name)
|
||||
|
||||
def _prepare(self):
|
||||
lr = self._call_if_callable(self._lr)
|
||||
beta1 = self._call_if_callable(self._beta1)
|
||||
beta2 = self._call_if_callable(self._beta2)
|
||||
epsilon = self._call_if_callable(self._epsilon)
|
||||
|
||||
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
|
||||
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
|
||||
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
|
||||
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
|
||||
|
||||
def _apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
|
||||
clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
|
||||
grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
|
||||
# Clip gradients by 3 std
|
||||
return training_ops.apply_adam(
|
||||
var,
|
||||
m,
|
||||
v,
|
||||
math_ops.cast(beta1_power, var.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, var.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
|
||||
grad,
|
||||
use_locking=self._use_locking,
|
||||
).op
|
||||
|
||||
def _resource_apply_dense(self, grad, var):
|
||||
m = self.get_slot(var, "m")
|
||||
v = self.get_slot(var, "v")
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
return training_ops.resource_apply_adam(
|
||||
var.handle,
|
||||
m.handle,
|
||||
v.handle,
|
||||
math_ops.cast(beta1_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(beta2_power, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
|
||||
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
|
||||
grad,
|
||||
use_locking=self._use_locking,
|
||||
)
|
||||
|
||||
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
|
||||
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
|
||||
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
||||
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
||||
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
||||
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
||||
lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)
|
||||
# m_t = beta1 * m + (1 - beta1) * g_t
|
||||
m = self.get_slot(var, "m")
|
||||
m_scaled_g_values = grad * (1 - beta1_t)
|
||||
m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
|
||||
with ops.control_dependencies([m_t]):
|
||||
m_t = scatter_add(m, indices, m_scaled_g_values)
|
||||
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
||||
v = self.get_slot(var, "v")
|
||||
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
||||
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
|
||||
with ops.control_dependencies([v_t]):
|
||||
v_t = scatter_add(v, indices, v_scaled_g_values)
|
||||
v_sqrt = math_ops.sqrt(v_t)
|
||||
var_update = state_ops.assign_sub(
|
||||
var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking
|
||||
)
|
||||
return control_flow_ops.group(*[var_update, m_t, v_t])
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
return self._apply_sparse_shared(
|
||||
grad.values,
|
||||
var,
|
||||
grad.indices,
|
||||
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
|
||||
x, i, v, use_locking=self._use_locking
|
||||
),
|
||||
)
|
||||
|
||||
def _resource_scatter_add(self, x, i, v):
|
||||
with ops.control_dependencies(
|
||||
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]
|
||||
):
|
||||
return x.value()
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
|
||||
|
||||
def _finish(self, update_ops, name_scope):
|
||||
# Update the power accumulators.
|
||||
with ops.control_dependencies(update_ops):
|
||||
beta1_power, beta2_power = self._get_beta_accumulators()
|
||||
with ops.colocate_with(beta1_power):
|
||||
update_beta1 = beta1_power.assign(
|
||||
beta1_power * self._beta1_t, use_locking=self._use_locking
|
||||
)
|
||||
update_beta2 = beta2_power.assign(
|
||||
beta2_power * self._beta2_t, use_locking=self._use_locking
|
||||
)
|
||||
return control_flow_ops.group(
|
||||
*update_ops + [update_beta1, update_beta2], name=name_scope
|
||||
)
|
||||
|
|
311
EBMs/data.py
311
EBMs/data.py
|
@ -1,42 +1,48 @@
|
|||
from tensorflow.python.platform import flags
|
||||
from tensorflow.contrib.data.python.ops import batching, threadpool
|
||||
import tensorflow as tf
|
||||
import json
|
||||
from torch.utils.data import Dataset
|
||||
import pickle
|
||||
import os.path as osp
|
||||
import os
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import time
|
||||
from scipy.misc import imread, imresize
|
||||
from skimage.color import rgb2grey
|
||||
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
|
||||
from torchvision import transforms
|
||||
from imagenet_preprocessing import ImagenetPreprocessor
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import torchvision
|
||||
from imagenet_preprocessing import ImagenetPreprocessor
|
||||
from scipy.misc import imread, imresize
|
||||
from tensorflow.contrib.data.python.ops import batching
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
ROOT_DIR = "./results"
|
||||
|
||||
# Dataset Options
|
||||
flags.DEFINE_string('dsprites_path',
|
||||
'/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
|
||||
'path to dsprites characters')
|
||||
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image')
|
||||
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
|
||||
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
|
||||
flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
|
||||
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
|
||||
flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
|
||||
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
|
||||
flags.DEFINE_string(
|
||||
"dsprites_path",
|
||||
"/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
|
||||
"path to dsprites characters",
|
||||
)
|
||||
flags.DEFINE_string(
|
||||
"imagenet_datadir", "/root/imagenet_big", "whether cutoff should always in image"
|
||||
)
|
||||
flags.DEFINE_bool("dshape_only", False, "fix all factors except for shapes")
|
||||
flags.DEFINE_bool("dpos_only", False, "fix all factors except for positions of shapes")
|
||||
flags.DEFINE_bool("dsize_only", False, "fix all factors except for size of objects")
|
||||
flags.DEFINE_bool("drot_only", False, "fix all factors except for rotation of objects")
|
||||
flags.DEFINE_bool(
|
||||
"dsprites_restrict", False, "fix all factors except for rotation of objects"
|
||||
)
|
||||
flags.DEFINE_string("imagenet_path", "/root/imagenet", "path to imagenet images")
|
||||
|
||||
|
||||
# Data augmentation options
|
||||
flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image')
|
||||
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
|
||||
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
|
||||
flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
|
||||
flags.DEFINE_bool("cutout_inside", False, "whether cutoff should always in image")
|
||||
flags.DEFINE_float("cutout_prob", 1.0, "probability of using cutout")
|
||||
flags.DEFINE_integer("cutout_mask_size", 16, "size of cutout")
|
||||
flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")
|
||||
|
||||
|
||||
def cutout(mask_color=(0, 0, 0)):
|
||||
|
@ -91,13 +97,15 @@ class TFImagenetLoader(Dataset):
|
|||
|
||||
self.curr_sample = 0
|
||||
|
||||
index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
|
||||
index_path = osp.join(FLAGS.imagenet_datadir, "index.json")
|
||||
with open(index_path) as f:
|
||||
metadata = json.load(f)
|
||||
counts = metadata['record_counts']
|
||||
counts = metadata["record_counts"]
|
||||
|
||||
if split == 'train':
|
||||
file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
|
||||
if split == "train":
|
||||
file_names = list(
|
||||
sorted([x for x in counts.keys() if x.startswith("train")])
|
||||
)
|
||||
|
||||
result_records_to_skip = None
|
||||
files = []
|
||||
|
@ -111,30 +119,44 @@ class TFImagenetLoader(Dataset):
|
|||
# Record the number to skip in the first file
|
||||
result_records_to_skip = records_to_skip
|
||||
files.append(filename)
|
||||
records_to_read -= (records_in_file - records_to_skip)
|
||||
records_to_read -= records_in_file - records_to_skip
|
||||
records_to_skip = 0
|
||||
else:
|
||||
break
|
||||
else:
|
||||
files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
|
||||
files = list(
|
||||
sorted([x for x in counts.keys() if x.startswith("validation")])
|
||||
)
|
||||
|
||||
files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
|
||||
preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
|
||||
preprocess_function = ImagenetPreprocessor(
|
||||
128, dtype=tf.float32, train=False
|
||||
).parse_and_preprocess
|
||||
|
||||
ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
|
||||
ds = tf.data.TFRecordDataset.from_generator(
|
||||
lambda: files, output_types=tf.string
|
||||
)
|
||||
ds = ds.apply(tf.data.TFRecordDataset)
|
||||
ds = ds.take(im_length)
|
||||
ds = ds.prefetch(buffer_size=FLAGS.batch_size)
|
||||
ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
|
||||
ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
|
||||
ds = ds.apply(
|
||||
batching.map_and_batch(
|
||||
map_func=preprocess_function,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_parallel_batches=4,
|
||||
)
|
||||
)
|
||||
ds = ds.prefetch(buffer_size=2)
|
||||
|
||||
ds_iterator = ds.make_initializable_iterator()
|
||||
labels, images = ds_iterator.get_next()
|
||||
self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
|
||||
self.images = tf.clip_by_value(
|
||||
images / 256 + tf.random_uniform(tf.shape(images), 0, 1.0 / 256), 0.0, 1.0
|
||||
)
|
||||
self.labels = labels
|
||||
|
||||
config = tf.ConfigProto(device_count = {'GPU': 0})
|
||||
config = tf.ConfigProto(device_count={"GPU": 0})
|
||||
sess = tf.Session(config=config)
|
||||
sess.run(ds_iterator.initializer)
|
||||
|
||||
|
@ -147,11 +169,17 @@ class TFImagenetLoader(Dataset):
|
|||
|
||||
sess = self.sess
|
||||
|
||||
im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
|
||||
im_corrupt = np.random.uniform(
|
||||
0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)
|
||||
)
|
||||
label, im = sess.run([self.labels, self.images])
|
||||
im = im * self.rescale
|
||||
label = np.eye(1000)[label.squeeze() - 1]
|
||||
im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
|
||||
im, im_corrupt, label = (
|
||||
torch.from_numpy(im),
|
||||
torch.from_numpy(im_corrupt),
|
||||
torch.from_numpy(label),
|
||||
)
|
||||
return im_corrupt, im, label
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -160,6 +188,7 @@ class TFImagenetLoader(Dataset):
|
|||
def __len__(self):
|
||||
return self.im_length
|
||||
|
||||
|
||||
class CelebA(Dataset):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -180,25 +209,18 @@ class CelebA(Dataset):
|
|||
im = imread(path)
|
||||
im = imresize(im, (32, 32))
|
||||
image_size = 32
|
||||
im = im / 255.
|
||||
im = im / 255.0
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0, 1, size=(image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0, 1, size=(image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
||||
class Cifar10(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
train=True,
|
||||
full=False,
|
||||
augment=False,
|
||||
noise=True,
|
||||
rescale=1.0):
|
||||
def __init__(self, train=True, full=False, augment=False, noise=True, rescale=1.0):
|
||||
|
||||
if augment:
|
||||
transform_list = [
|
||||
|
@ -215,16 +237,10 @@ class Cifar10(Dataset):
|
|||
transform = transforms.ToTensor()
|
||||
|
||||
self.full = full
|
||||
self.data = CIFAR10(
|
||||
ROOT_DIR,
|
||||
transform=transform,
|
||||
train=train,
|
||||
download=True)
|
||||
self.data = CIFAR10(ROOT_DIR, transform=transform, train=train, download=True)
|
||||
self.test_data = CIFAR10(
|
||||
ROOT_DIR,
|
||||
transform=transform,
|
||||
train=False,
|
||||
download=True)
|
||||
ROOT_DIR, transform=transform, train=False, download=True
|
||||
)
|
||||
self.one_hot_map = np.eye(10)
|
||||
self.noise = noise
|
||||
self.rescale = rescale
|
||||
|
@ -255,16 +271,18 @@ class Cifar10(Dataset):
|
|||
im = im * 255 / 256
|
||||
|
||||
if self.noise:
|
||||
im = im * self.rescale + \
|
||||
np.random.uniform(0, self.rescale * 1 / 256., im.shape)
|
||||
im = im * self.rescale + np.random.uniform(
|
||||
0, self.rescale * 1 / 256.0, im.shape
|
||||
)
|
||||
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, self.rescale, (image_size, image_size, 3))
|
||||
0.0, self.rescale, (image_size, image_size, 3)
|
||||
)
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
@ -287,10 +305,8 @@ class Cifar100(Dataset):
|
|||
transform = transforms.ToTensor()
|
||||
|
||||
self.data = CIFAR100(
|
||||
"/root/cifar100",
|
||||
transform=transform,
|
||||
train=train,
|
||||
download=True)
|
||||
"/root/cifar100", transform=transform, train=train, download=True
|
||||
)
|
||||
self.one_hot_map = np.eye(100)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -308,11 +324,10 @@ class Cifar100(Dataset):
|
|||
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, 1.0, (image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
@ -340,11 +355,10 @@ class Svhn(Dataset):
|
|||
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, 1.0, (image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
@ -352,9 +366,8 @@ class Svhn(Dataset):
|
|||
class Mnist(Dataset):
|
||||
def __init__(self, train=True, rescale=1.0):
|
||||
self.data = MNIST(
|
||||
"/root/mnist",
|
||||
transform=transforms.ToTensor(),
|
||||
download=True, train=train)
|
||||
"/root/mnist", transform=transforms.ToTensor(), download=True, train=train
|
||||
)
|
||||
self.labels = np.eye(10)
|
||||
self.rescale = rescale
|
||||
|
||||
|
@ -367,13 +380,13 @@ class Mnist(Dataset):
|
|||
im = im.squeeze()
|
||||
# im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
|
||||
# im = im.numpy() / 2 + 0.2
|
||||
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
|
||||
im = im.numpy() / 256 * 255 + np.random.uniform(0, 1.0 / 256, (28, 28))
|
||||
im = im * self.rescale
|
||||
image_size = 28
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
|
||||
elif FLAGS.datasource == 'random':
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
@ -381,54 +394,63 @@ class Mnist(Dataset):
|
|||
|
||||
class DSprites(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
cond_size=False,
|
||||
cond_shape=False,
|
||||
cond_pos=False,
|
||||
cond_rot=False):
|
||||
self, cond_size=False, cond_shape=False, cond_pos=False, cond_rot=False
|
||||
):
|
||||
dat = np.load(FLAGS.dsprites_path)
|
||||
|
||||
if FLAGS.dshape_only:
|
||||
l = dat['latents_values']
|
||||
mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
|
||||
31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
|
||||
l = dat["latents_values"]
|
||||
mask = (
|
||||
(l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 2] == 0.5)
|
||||
& (l[:, 3] == 30 * np.pi / 39)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
|
||||
self.label = self.label[:, 1:2]
|
||||
elif FLAGS.dpos_only:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
mask = (l[:, 1] == 1) & (
|
||||
l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (100, 1))
|
||||
mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (100, 1))
|
||||
self.label = self.label[:, 4:] + 0.5
|
||||
elif FLAGS.dsize_only:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
|
||||
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
|
||||
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
|
||||
self.label = (self.label[:, 2:3])
|
||||
mask = (
|
||||
(l[:, 3] == 30 * np.pi / 39)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
|
||||
self.label = self.label[:, 2:3]
|
||||
elif FLAGS.drot_only:
|
||||
l = dat['latents_values']
|
||||
mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
|
||||
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
|
||||
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (100, 1))
|
||||
self.label = (self.label[:, 3:4])
|
||||
l = dat["latents_values"]
|
||||
mask = (
|
||||
(l[:, 2] == 0.5)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (100, 1))
|
||||
self.label = self.label[:, 3:4]
|
||||
self.label = np.concatenate(
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1)
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1
|
||||
)
|
||||
elif FLAGS.dsprites_restrict:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
|
||||
|
||||
self.data = dat['imgs'][mask]
|
||||
self.label = dat['latents_values'][mask]
|
||||
self.data = dat["imgs"][mask]
|
||||
self.label = dat["latents_values"][mask]
|
||||
else:
|
||||
self.data = dat['imgs']
|
||||
self.label = dat['latents_values']
|
||||
self.data = dat["imgs"]
|
||||
self.label = dat["latents_values"]
|
||||
|
||||
if cond_size:
|
||||
self.label = self.label[:, 2:3]
|
||||
|
@ -439,7 +461,8 @@ class DSprites(Dataset):
|
|||
elif cond_rot:
|
||||
self.label = self.label[:, 3:4]
|
||||
self.label = np.concatenate(
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1)
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1
|
||||
)
|
||||
else:
|
||||
self.label = self.label[:, 1:2]
|
||||
|
||||
|
@ -452,20 +475,20 @@ class DSprites(Dataset):
|
|||
im = self.data[index]
|
||||
image_size = 64
|
||||
|
||||
if not (
|
||||
FLAGS.dpos_only or FLAGS.dsize_only) and (
|
||||
not FLAGS.cond_size) and (
|
||||
not FLAGS.cond_pos) and (
|
||||
not FLAGS.cond_rot) and (
|
||||
not FLAGS.drot_only):
|
||||
label = self.identity[self.label[index].astype(
|
||||
np.int32) - 1].squeeze()
|
||||
if (
|
||||
not (FLAGS.dpos_only or FLAGS.dsize_only)
|
||||
and (not FLAGS.cond_size)
|
||||
and (not FLAGS.cond_pos)
|
||||
and (not FLAGS.cond_rot)
|
||||
and (not FLAGS.drot_only)
|
||||
):
|
||||
label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
|
||||
else:
|
||||
label = self.label[index]
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
|
||||
elif FLAGS.datasource == 'random':
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
@ -478,25 +501,20 @@ class Imagenet(Dataset):
|
|||
for i in range(1, 11):
|
||||
f = pickle.load(
|
||||
open(
|
||||
osp.join(
|
||||
FLAGS.imagenet_path,
|
||||
'train_data_batch_{}'.format(i)),
|
||||
'rb'))
|
||||
osp.join(FLAGS.imagenet_path, "train_data_batch_{}".format(i)),
|
||||
"rb",
|
||||
)
|
||||
)
|
||||
if i == 1:
|
||||
labels = f['labels']
|
||||
data = f['data']
|
||||
labels = f["labels"]
|
||||
data = f["data"]
|
||||
else:
|
||||
labels.extend(f['labels'])
|
||||
data = np.vstack((data, f['data']))
|
||||
labels.extend(f["labels"])
|
||||
data = np.vstack((data, f["data"]))
|
||||
else:
|
||||
f = pickle.load(
|
||||
open(
|
||||
osp.join(
|
||||
FLAGS.imagenet_path,
|
||||
'val_data'),
|
||||
'rb'))
|
||||
labels = f['labels']
|
||||
data = f['data']
|
||||
f = pickle.load(open(osp.join(FLAGS.imagenet_path, "val_data"), "rb"))
|
||||
labels = f["labels"]
|
||||
data = f["data"]
|
||||
|
||||
self.labels = labels
|
||||
self.data = data
|
||||
|
@ -520,11 +538,10 @@ class Imagenet(Dataset):
|
|||
im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
|
||||
np.random.seed((index + int(time.time() * 1e7)) % 2**32)
|
||||
|
||||
if FLAGS.datasource == 'default':
|
||||
if FLAGS.datasource == "default":
|
||||
im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
|
||||
elif FLAGS.datasource == 'random':
|
||||
im_corrupt = np.random.uniform(
|
||||
0.0, 1.0, (image_size, image_size, 3))
|
||||
elif FLAGS.datasource == "random":
|
||||
im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
|
||||
|
||||
return im_corrupt, im, label
|
||||
|
||||
|
|
|
@ -1,68 +1,100 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from hmc import hmc
|
||||
from custom_adam import AdamOptimizer
|
||||
from models import DspritesNet
|
||||
from scipy.misc import imsave
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from models import DspritesNet
|
||||
from utils import optimistic_restore, ReplayBuffer
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
from rl_algs.logger import TensorBoardOutputFormat
|
||||
from scipy.misc import imsave
|
||||
import os
|
||||
from custom_adam import AdamOptimizer
|
||||
from tqdm import tqdm
|
||||
from utils import ReplayBuffer
|
||||
|
||||
flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
|
||||
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
|
||||
flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
|
||||
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
|
||||
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
|
||||
flags.DEFINE_bool('cclass', True, 'not cclass')
|
||||
flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
|
||||
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
|
||||
flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
|
||||
flags.DEFINE_bool('joint_shape', False, 'whether to use pos_size or pos_shape')
|
||||
flags.DEFINE_bool('joint_rot', False, 'whether to use pos_size or pos_shape')
|
||||
flags.DEFINE_integer("batch_size", 256, "Size of inputs")
|
||||
flags.DEFINE_integer("data_workers", 4, "Number of workers to do things")
|
||||
flags.DEFINE_string("logdir", "cachedir", "directory for logging")
|
||||
flags.DEFINE_string(
|
||||
"savedir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_integer(
|
||||
"num_filters",
|
||||
64,
|
||||
"number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.",
|
||||
)
|
||||
flags.DEFINE_float("step_lr", 500, "size of gradient descent size")
|
||||
flags.DEFINE_string(
|
||||
"dsprites_path",
|
||||
"/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
|
||||
"path to dsprites characters",
|
||||
)
|
||||
flags.DEFINE_bool("cclass", True, "not cclass")
|
||||
flags.DEFINE_bool("proj_cclass", False, "use for backwards compatibility reasons")
|
||||
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
|
||||
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
|
||||
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
|
||||
flags.DEFINE_bool("plot_curve", False, "Generate a curve of results")
|
||||
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
|
||||
flags.DEFINE_string(
|
||||
"task",
|
||||
"conceptcombine",
|
||||
"conceptcombine, labeldiscover, gentest, genbaseline, etc.",
|
||||
)
|
||||
flags.DEFINE_bool("joint_shape", False, "whether to use pos_size or pos_shape")
|
||||
flags.DEFINE_bool("joint_rot", False, "whether to use pos_size or pos_shape")
|
||||
|
||||
# Conditions on which models to use
|
||||
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
|
||||
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
|
||||
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
|
||||
flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')
|
||||
flags.DEFINE_bool("cond_pos", True, "whether to condition on position")
|
||||
flags.DEFINE_bool("cond_rot", True, "whether to condition on rotation")
|
||||
flags.DEFINE_bool("cond_shape", True, "whether to condition on shape")
|
||||
flags.DEFINE_bool("cond_scale", True, "whether to condition on scale")
|
||||
|
||||
flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
|
||||
flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
|
||||
flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
|
||||
flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
|
||||
flags.DEFINE_integer('resume_size', 169000, 'First iteration to resume')
|
||||
flags.DEFINE_integer('resume_shape', 477000, 'Second iteration to resume')
|
||||
flags.DEFINE_integer('resume_pos', 8000, 'Second iteration to resume')
|
||||
flags.DEFINE_integer('resume_rot', 690000, 'Second iteration to resume')
|
||||
flags.DEFINE_integer('break_steps', 300, 'steps to break')
|
||||
flags.DEFINE_string("exp_size", "dsprites_2018_cond_size", "name of experiments")
|
||||
flags.DEFINE_string("exp_shape", "dsprites_2018_cond_shape", "name of experiments")
|
||||
flags.DEFINE_string("exp_pos", "dsprites_2018_cond_pos_cert", "name of experiments")
|
||||
flags.DEFINE_string("exp_rot", "dsprites_cond_rot_119_00", "name of experiments")
|
||||
flags.DEFINE_integer("resume_size", 169000, "First iteration to resume")
|
||||
flags.DEFINE_integer("resume_shape", 477000, "Second iteration to resume")
|
||||
flags.DEFINE_integer("resume_pos", 8000, "Second iteration to resume")
|
||||
flags.DEFINE_integer("resume_rot", 690000, "Second iteration to resume")
|
||||
flags.DEFINE_integer("break_steps", 300, "steps to break")
|
||||
|
||||
# Whether to train for gentest
|
||||
flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')
|
||||
flags.DEFINE_bool(
|
||||
"train",
|
||||
False,
|
||||
"whether to train on generalization into multiple different predictions",
|
||||
)
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class DSpritesGen(Dataset):
|
||||
def __init__(self, data, latents, frac=0.0):
|
||||
|
||||
l = latents
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
|
||||
mask_size = (
|
||||
(l[:, 3] == 30 * np.pi / 39)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 2] == 0.5)
|
||||
)
|
||||
elif FLAGS.joint_rot:
|
||||
mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
|
||||
mask_size = (
|
||||
(l[:, 1] == 1)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 2] == 0.5)
|
||||
)
|
||||
else:
|
||||
mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)
|
||||
mask_size = (
|
||||
(l[:, 3] == 30 * np.pi / 39)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
|
||||
mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
|
||||
|
@ -80,12 +112,14 @@ class DSpritesGen(Dataset):
|
|||
self.data = np.concatenate((data_pos, data_size), axis=0)
|
||||
self.label = np.concatenate((l_pos, l_size), axis=0)
|
||||
|
||||
mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
|
||||
mask_neg = (~(mask_size & mask_pos)) & (
|
||||
(l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)
|
||||
)
|
||||
data_add = data[mask_neg]
|
||||
l_add = l[mask_neg]
|
||||
|
||||
perm_idx = np.random.permutation(data_add.shape[0])
|
||||
select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
|
||||
select_idx = perm_idx[: int(frac * perm_idx.shape[0])]
|
||||
data_add = data_add[select_idx]
|
||||
l_add = l_add[select_idx]
|
||||
|
||||
|
@ -104,7 +138,9 @@ class DSpritesGen(Dataset):
|
|||
if FLAGS.joint_shape:
|
||||
label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
|
||||
elif FLAGS.joint_rot:
|
||||
label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
|
||||
label_size = np.array(
|
||||
[np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]
|
||||
)
|
||||
else:
|
||||
label_size = self.label[index, 2:3]
|
||||
|
||||
|
@ -114,14 +150,16 @@ class DSpritesGen(Dataset):
|
|||
|
||||
|
||||
def labeldiscover(sess, kvs, data, latents, save_exp_dir):
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
model_size = kvs['model_size']
|
||||
weight_size = kvs['weight_size']
|
||||
x_mod = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs["LABEL_SIZE"]
|
||||
model_size = kvs["model_size"]
|
||||
weight_size = kvs["weight_size"]
|
||||
x_mod = kvs["X_NOISE"]
|
||||
|
||||
label_output = LABEL_SIZE
|
||||
for i in range(FLAGS.num_steps):
|
||||
label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
|
||||
label_output = label_output + tf.random_normal(
|
||||
tf.shape(label_output), mean=0.0, stddev=0.03
|
||||
)
|
||||
e_noise = model_size.forward(x_mod, weight_size, label=label_output)
|
||||
label_grad = tf.gradients(e_noise, [label_output])[0]
|
||||
# label_grad = tf.Print(label_grad, [label_grad])
|
||||
|
@ -130,13 +168,13 @@ def labeldiscover(sess, kvs, data, latents, save_exp_dir):
|
|||
|
||||
diffs = []
|
||||
for i in range(30):
|
||||
s = i*FLAGS.batch_size
|
||||
d = (i+1)*FLAGS.batch_size
|
||||
s = i * FLAGS.batch_size
|
||||
d = (i + 1) * FLAGS.batch_size
|
||||
data_i = data[s:d]
|
||||
latent_i = latents[s:d]
|
||||
latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
|
||||
|
||||
feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
|
||||
feed_dict = {x_mod: data_i, LABEL_SIZE: latent_init}
|
||||
size_pred = sess.run([label_output], feed_dict)[0]
|
||||
size_gt = latent_i[:, 2:3]
|
||||
|
||||
|
@ -155,9 +193,11 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
|
||||
LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
|
||||
|
||||
weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
|
||||
weights_baseline = model_baseline.construct_weights(
|
||||
"context_baseline_{}".format(frac)
|
||||
)
|
||||
|
||||
X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
|
||||
X_feed = tf.placeholder(shape=(None, 2 * FLAGS.num_filters), dtype=tf.float32)
|
||||
X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
|
||||
X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
|
||||
|
@ -168,14 +208,20 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
gvs = [(k, v) for (k, v) in gvs if k is not None]
|
||||
train_op = optimizer.apply_gradients(gvs)
|
||||
|
||||
dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
|
||||
dataloader = DataLoader(
|
||||
DSpritesGen(data, latents, frac=frac),
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=6,
|
||||
drop_last=True,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
datafull = data
|
||||
|
||||
itr = 0
|
||||
saver = tf.train.Saver()
|
||||
|
||||
vs = optimizer.variables()
|
||||
optimizer.variables()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
if FLAGS.train:
|
||||
|
@ -185,7 +231,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
data_corrupt = data_corrupt.numpy()
|
||||
label_size, label_pos = label_size.numpy(), label_pos.numpy()
|
||||
|
||||
data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
|
||||
data_corrupt = np.random.randn(
|
||||
data_corrupt.shape[0], 2 * FLAGS.num_filters
|
||||
)
|
||||
label_comb = np.concatenate([label_size, label_pos], axis=1)
|
||||
|
||||
feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
|
||||
|
@ -196,23 +244,27 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
|
||||
itr += 1
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
|
||||
saver.save(sess, osp.join(save_exp_dir, "model_genbaseline"))
|
||||
|
||||
saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
|
||||
saver.restore(sess, osp.join(save_exp_dir, "model_genbaseline"))
|
||||
|
||||
l = latents
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
|
||||
else:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
|
||||
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
|
||||
)
|
||||
|
||||
data_gen = datafull[mask_gen]
|
||||
latents_gen = latents[mask_gen]
|
||||
losses = []
|
||||
|
||||
for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
|
||||
data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
|
||||
for dat, latent in zip(
|
||||
np.array_split(data_gen, 10), np.array_split(latents_gen, 10)
|
||||
):
|
||||
data_init = np.random.randn(dat.shape[0], 2 * FLAGS.num_filters)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
|
||||
|
@ -220,16 +272,19 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
latent = np.concatenate([latent_size, latent_pos], axis=1)
|
||||
feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
|
||||
else:
|
||||
feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
|
||||
feed_dict = {X_feed: data_init, LABEL: latent[:, [2, 4, 5]], X_label: dat}
|
||||
loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
|
||||
# print(loss)
|
||||
losses.append(loss)
|
||||
|
||||
print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))
|
||||
|
||||
print(
|
||||
"Overall MSE for generalization of {} for fraction of {}".format(
|
||||
np.mean(losses), frac
|
||||
)
|
||||
)
|
||||
|
||||
data_try = data_gen[:10]
|
||||
data_init = np.random.randn(10, 2*FLAGS.num_filters)
|
||||
data_init = np.random.randn(10, 2 * FLAGS.num_filters)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
|
||||
|
@ -252,7 +307,9 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
x_output_wrap[:, 1:-1, 1:-1] = x_output
|
||||
data_try_wrap[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
|
||||
-1, 66 * 2
|
||||
)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
@ -261,38 +318,40 @@ def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
|
|||
|
||||
|
||||
def gentest(sess, kvs, data, latents, save_exp_dir):
|
||||
X_NOISE = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
LABEL_SHAPE = kvs['LABEL_SHAPE']
|
||||
LABEL_POS = kvs['LABEL_POS']
|
||||
LABEL_ROT = kvs['LABEL_ROT']
|
||||
model_size = kvs['model_size']
|
||||
model_shape = kvs['model_shape']
|
||||
model_pos = kvs['model_pos']
|
||||
model_rot = kvs['model_rot']
|
||||
weight_size = kvs['weight_size']
|
||||
weight_shape = kvs['weight_shape']
|
||||
weight_pos = kvs['weight_pos']
|
||||
weight_rot = kvs['weight_rot']
|
||||
X_NOISE = kvs["X_NOISE"]
|
||||
LABEL_SIZE = kvs["LABEL_SIZE"]
|
||||
LABEL_SHAPE = kvs["LABEL_SHAPE"]
|
||||
LABEL_POS = kvs["LABEL_POS"]
|
||||
LABEL_ROT = kvs["LABEL_ROT"]
|
||||
model_size = kvs["model_size"]
|
||||
model_shape = kvs["model_shape"]
|
||||
model_pos = kvs["model_pos"]
|
||||
model_rot = kvs["model_rot"]
|
||||
weight_size = kvs["weight_size"]
|
||||
weight_shape = kvs["weight_shape"]
|
||||
weight_pos = kvs["weight_pos"]
|
||||
weight_rot = kvs["weight_rot"]
|
||||
X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
|
||||
|
||||
datafull = data
|
||||
# Test combination of generalization where we use slices of both training
|
||||
x_final = X_NOISE
|
||||
x_mod_size = X_NOISE
|
||||
x_mod_pos = X_NOISE
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
|
||||
# use cond_pos
|
||||
|
||||
energies = []
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(
|
||||
tf.shape(x_mod_pos), mean=0.0, stddev=0.005
|
||||
)
|
||||
e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
|
||||
|
||||
# energies.append(e_noise)
|
||||
x_grad = tf.gradients(e_noise, [x_final])[0]
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
|
||||
x_mod_pos = x_mod_pos + tf.random_normal(
|
||||
tf.shape(x_mod_pos), mean=0.0, stddev=0.005
|
||||
)
|
||||
x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
|
||||
x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
|
||||
|
||||
|
@ -332,42 +391,70 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
x_mod = x_mod_pos
|
||||
x_final = x_mod
|
||||
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
loss_kl = model_shape.forward(
|
||||
x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True
|
||||
) + model_pos.forward(
|
||||
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
|
||||
)
|
||||
|
||||
energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_pos = model_shape.forward(
|
||||
X, weight_shape, reuse=True, label=LABEL_SHAPE
|
||||
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_neg = model_shape.forward(
|
||||
tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE
|
||||
) + model_pos.forward(
|
||||
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
|
||||
)
|
||||
elif FLAGS.joint_rot:
|
||||
loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
loss_kl = model_rot.forward(
|
||||
x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True
|
||||
) + model_pos.forward(
|
||||
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
|
||||
)
|
||||
|
||||
energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_pos = model_rot.forward(
|
||||
X, weight_rot, reuse=True, label=LABEL_ROT
|
||||
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_neg = model_rot.forward(
|
||||
tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT
|
||||
) + model_pos.forward(
|
||||
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
|
||||
)
|
||||
else:
|
||||
loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
|
||||
model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
|
||||
loss_kl = model_size.forward(
|
||||
x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True
|
||||
) + model_pos.forward(
|
||||
x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True
|
||||
)
|
||||
|
||||
energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
|
||||
model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_pos = model_size.forward(
|
||||
X, weight_size, reuse=True, label=LABEL_SIZE
|
||||
) + model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
|
||||
|
||||
energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
|
||||
model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
|
||||
energy_neg = model_size.forward(
|
||||
tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE
|
||||
) + model_pos.forward(
|
||||
tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS
|
||||
)
|
||||
|
||||
energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
|
||||
energy_neg_reduced = energy_neg - tf.reduce_min(energy_neg)
|
||||
coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
|
||||
norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
|
||||
neg_loss = coeff * (-1*energy_neg) / norm_constant
|
||||
coeff * (-1 * energy_neg) / norm_constant
|
||||
|
||||
loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
|
||||
loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
|
||||
loss_total = (
|
||||
loss_ml
|
||||
+ tf.reduce_mean(loss_kl)
|
||||
+ 1
|
||||
* (
|
||||
tf.reduce_mean(tf.square(energy_pos))
|
||||
+ tf.reduce_mean(tf.square(energy_neg))
|
||||
)
|
||||
)
|
||||
|
||||
optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
|
||||
gvs = optimizer.compute_gradients(loss_total)
|
||||
|
@ -377,7 +464,13 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
vs = optimizer.variables()
|
||||
sess.run(tf.variables_initializer(vs))
|
||||
|
||||
dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
|
||||
dataloader = DataLoader(
|
||||
DSpritesGen(data, latents),
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=6,
|
||||
drop_last=True,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
x_off = tf.reduce_mean(tf.square(x_mod - X))
|
||||
|
||||
|
@ -385,12 +478,10 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
saver = tf.train.Saver()
|
||||
x_mod = None
|
||||
|
||||
|
||||
if FLAGS.train:
|
||||
replay_buffer = ReplayBuffer(10000)
|
||||
for _ in range(1):
|
||||
|
||||
|
||||
for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
|
||||
data_corrupt = data_corrupt.numpy()[:, :, :]
|
||||
data = data.numpy()[:, :, :]
|
||||
|
@ -398,29 +489,50 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
if x_mod is not None:
|
||||
replay_buffer.add(x_mod)
|
||||
replay_batch = replay_buffer.sample(FLAGS.batch_size)
|
||||
replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
|
||||
replay_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95
|
||||
data_corrupt[replay_mask] = replay_batch[replay_mask]
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_corrupt,
|
||||
X: data,
|
||||
LABEL_SHAPE: label_size,
|
||||
LABEL_POS: label_pos,
|
||||
}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_corrupt,
|
||||
X: data,
|
||||
LABEL_ROT: label_size,
|
||||
LABEL_POS: label_pos,
|
||||
}
|
||||
else:
|
||||
feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_corrupt,
|
||||
X: data,
|
||||
LABEL_SIZE: label_size,
|
||||
LABEL_POS: label_pos,
|
||||
}
|
||||
|
||||
_, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
|
||||
_, off_value, e_pos, e_neg, x_mod = sess.run(
|
||||
[train_op, x_off, energy_pos, energy_neg, x_final],
|
||||
feed_dict=feed_dict,
|
||||
)
|
||||
itr += 1
|
||||
|
||||
if itr % 10 == 0:
|
||||
print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
|
||||
print(
|
||||
"x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(
|
||||
off_value, e_pos.mean(), e_neg.mean(), itr
|
||||
)
|
||||
)
|
||||
|
||||
if itr == FLAGS.break_steps:
|
||||
break
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, "model_gentest"))
|
||||
|
||||
saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))
|
||||
|
||||
saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
|
||||
saver.restore(sess, osp.join(save_exp_dir, "model_gentest"))
|
||||
|
||||
l = latents
|
||||
|
||||
|
@ -429,22 +541,43 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
elif FLAGS.joint_rot:
|
||||
mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
|
||||
else:
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
|
||||
mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (
|
||||
~((l[:, 2] == 0.5) | ((l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31)))
|
||||
)
|
||||
|
||||
data_gen = datafull[mask_gen]
|
||||
latents_gen = latents[mask_gen]
|
||||
|
||||
losses = []
|
||||
|
||||
for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
|
||||
for dat, latent in zip(
|
||||
np.array_split(data_gen, 120), np.array_split(latents_gen, 120)
|
||||
):
|
||||
x = 0.5 + np.random.randn(*dat.shape)
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
feed_dict = {
|
||||
LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1],
|
||||
LABEL_POS: latent[:, 4:],
|
||||
X_NOISE: x,
|
||||
X: dat,
|
||||
}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
feed_dict = {
|
||||
LABEL_ROT: np.concatenate(
|
||||
[np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1
|
||||
),
|
||||
LABEL_POS: latent[:, 4:],
|
||||
X_NOISE: x,
|
||||
X: dat,
|
||||
}
|
||||
else:
|
||||
feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
|
||||
feed_dict = {
|
||||
LABEL_SIZE: latent[:, 2:3],
|
||||
LABEL_POS: latent[:, 4:],
|
||||
X_NOISE: x,
|
||||
X: dat,
|
||||
}
|
||||
|
||||
for i in range(2):
|
||||
x = sess.run([x_final], feed_dict=feed_dict)[0]
|
||||
|
@ -461,11 +594,25 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
latent_pos = latents_gen[:10, 4:]
|
||||
|
||||
if FLAGS.joint_shape:
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_init,
|
||||
LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32) - 1],
|
||||
LABEL_POS: latent_pos,
|
||||
}
|
||||
elif FLAGS.joint_rot:
|
||||
feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
|
||||
feed_dict = {
|
||||
LABEL_ROT: np.concatenate(
|
||||
[np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1
|
||||
),
|
||||
LABEL_POS: latent[:10, 4:],
|
||||
X_NOISE: data_init,
|
||||
}
|
||||
else:
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
|
||||
feed_dict = {
|
||||
X_NOISE: data_init,
|
||||
LABEL_SIZE: latent_scale,
|
||||
LABEL_POS: latent_pos,
|
||||
}
|
||||
|
||||
x_output = sess.run([x_final], feed_dict=feed_dict)[0]
|
||||
|
||||
|
@ -480,27 +627,28 @@ def gentest(sess, kvs, data, latents, save_exp_dir):
|
|||
x_output_wrap[:, 1:-1, 1:-1] = x_output
|
||||
data_try_wrap[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
|
||||
im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(
|
||||
-1, 66 * 2
|
||||
)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
||||
|
||||
|
||||
def conceptcombine(sess, kvs, data, latents, save_exp_dir):
|
||||
X_NOISE = kvs['X_NOISE']
|
||||
LABEL_SIZE = kvs['LABEL_SIZE']
|
||||
LABEL_SHAPE = kvs['LABEL_SHAPE']
|
||||
LABEL_POS = kvs['LABEL_POS']
|
||||
LABEL_ROT = kvs['LABEL_ROT']
|
||||
model_size = kvs['model_size']
|
||||
model_shape = kvs['model_shape']
|
||||
model_pos = kvs['model_pos']
|
||||
model_rot = kvs['model_rot']
|
||||
weight_size = kvs['weight_size']
|
||||
weight_shape = kvs['weight_shape']
|
||||
weight_pos = kvs['weight_pos']
|
||||
weight_rot = kvs['weight_rot']
|
||||
X_NOISE = kvs["X_NOISE"]
|
||||
LABEL_SIZE = kvs["LABEL_SIZE"]
|
||||
LABEL_SHAPE = kvs["LABEL_SHAPE"]
|
||||
LABEL_POS = kvs["LABEL_POS"]
|
||||
LABEL_ROT = kvs["LABEL_ROT"]
|
||||
model_size = kvs["model_size"]
|
||||
model_shape = kvs["model_shape"]
|
||||
model_pos = kvs["model_pos"]
|
||||
model_rot = kvs["model_rot"]
|
||||
weight_size = kvs["weight_size"]
|
||||
weight_shape = kvs["weight_shape"]
|
||||
weight_pos = kvs["weight_pos"]
|
||||
weight_rot = kvs["weight_rot"]
|
||||
|
||||
x_mod = X_NOISE
|
||||
for i in range(FLAGS.num_steps):
|
||||
|
@ -540,13 +688,18 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
|
|||
data_try = data[:10]
|
||||
data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
|
||||
label_scale = latents[:10, 2:3]
|
||||
label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
|
||||
label_shape = np.eye(3)[(latents[:10, 1] - 1).astype(np.uint8)]
|
||||
label_rot = latents[:10, 3:4]
|
||||
label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
|
||||
label_pos = latents[:10, 4:]
|
||||
|
||||
feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
|
||||
LABEL_ROT: label_rot}
|
||||
feed_dict = {
|
||||
X_NOISE: data_init,
|
||||
LABEL_SIZE: label_scale,
|
||||
LABEL_SHAPE: label_shape,
|
||||
LABEL_POS: label_pos,
|
||||
LABEL_ROT: label_rot,
|
||||
}
|
||||
x_out = sess.run([x_final], feed_dict)[0]
|
||||
|
||||
im_name = "im"
|
||||
|
@ -569,14 +722,15 @@ def conceptcombine(sess, kvs, data, latents, save_exp_dir):
|
|||
x_out_pad[:, 1:-1, 1:-1] = x_out
|
||||
data_try_pad[:, 1:-1, 1:-1] = data_try
|
||||
|
||||
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
|
||||
im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66 * 2)
|
||||
impath = osp.join(save_exp_dir, im_name)
|
||||
imsave(impath, im_output)
|
||||
print("Successfully saved images at {}".format(impath))
|
||||
|
||||
|
||||
def main():
|
||||
data = np.load(FLAGS.dsprites_path)['imgs']
|
||||
l = latents = np.load(FLAGS.dsprites_path)['latents_values']
|
||||
data = np.load(FLAGS.dsprites_path)["imgs"]
|
||||
l = latents = np.load(FLAGS.dsprites_path)["latents_values"]
|
||||
|
||||
np.random.seed(1)
|
||||
idx = np.random.permutation(data.shape[0])
|
||||
|
@ -589,52 +743,74 @@ def main():
|
|||
|
||||
# Model 1 will be conditioned on size
|
||||
model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
|
||||
weight_size = model_size.construct_weights('context_0')
|
||||
weight_size = model_size.construct_weights("context_0")
|
||||
|
||||
# Model 2 will be conditioned on shape
|
||||
model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
|
||||
weight_shape = model_shape.construct_weights('context_1')
|
||||
weight_shape = model_shape.construct_weights("context_1")
|
||||
|
||||
# Model 3 will be conditioned on position
|
||||
model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
|
||||
weight_pos = model_pos.construct_weights('context_2')
|
||||
weight_pos = model_pos.construct_weights("context_2")
|
||||
|
||||
# Model 4 will be conditioned on rotation
|
||||
model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
|
||||
weight_rot = model_rot.construct_weights('context_3')
|
||||
weight_rot = model_rot.construct_weights("context_3")
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
|
||||
save_path_size = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_size, "model_{}".format(FLAGS.resume_size)
|
||||
)
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
|
||||
v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(0)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(0), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
|
||||
if FLAGS.cond_scale:
|
||||
saver = tf.train.Saver(v_map)
|
||||
saver.restore(sess, save_path_size)
|
||||
|
||||
save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
|
||||
save_path_shape = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_shape, "model_{}".format(FLAGS.resume_shape)
|
||||
)
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
|
||||
v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(1)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(1), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
|
||||
if FLAGS.cond_shape:
|
||||
saver = tf.train.Saver(v_map)
|
||||
saver.restore(sess, save_path_shape)
|
||||
|
||||
|
||||
save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
|
||||
v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
|
||||
save_path_pos = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_pos, "model_{}".format(FLAGS.resume_pos)
|
||||
)
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(2)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(2), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
saver = tf.train.Saver(v_map)
|
||||
|
||||
if FLAGS.cond_pos:
|
||||
saver.restore(sess, save_path_pos)
|
||||
|
||||
|
||||
save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
|
||||
v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
|
||||
save_path_rot = osp.join(
|
||||
FLAGS.logdir, FLAGS.exp_rot, "model_{}".format(FLAGS.resume_rot)
|
||||
)
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(3)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(3), "context_0")[:-2]): v for v in v_list
|
||||
}
|
||||
saver = tf.train.Saver(v_map)
|
||||
|
||||
if FLAGS.cond_rot:
|
||||
|
@ -646,53 +822,57 @@ def main():
|
|||
LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
|
||||
|
||||
x_mod = X_NOISE
|
||||
|
||||
kvs = {}
|
||||
kvs['X_NOISE'] = X_NOISE
|
||||
kvs['LABEL_SIZE'] = LABEL_SIZE
|
||||
kvs['LABEL_SHAPE'] = LABEL_SHAPE
|
||||
kvs['LABEL_POS'] = LABEL_POS
|
||||
kvs['LABEL_ROT'] = LABEL_ROT
|
||||
kvs['model_size'] = model_size
|
||||
kvs['model_shape'] = model_shape
|
||||
kvs['model_pos'] = model_pos
|
||||
kvs['model_rot'] = model_rot
|
||||
kvs['weight_size'] = weight_size
|
||||
kvs['weight_shape'] = weight_shape
|
||||
kvs['weight_pos'] = weight_pos
|
||||
kvs['weight_rot'] = weight_rot
|
||||
kvs["X_NOISE"] = X_NOISE
|
||||
kvs["LABEL_SIZE"] = LABEL_SIZE
|
||||
kvs["LABEL_SHAPE"] = LABEL_SHAPE
|
||||
kvs["LABEL_POS"] = LABEL_POS
|
||||
kvs["LABEL_ROT"] = LABEL_ROT
|
||||
kvs["model_size"] = model_size
|
||||
kvs["model_shape"] = model_shape
|
||||
kvs["model_pos"] = model_pos
|
||||
kvs["model_rot"] = model_rot
|
||||
kvs["weight_size"] = weight_size
|
||||
kvs["weight_shape"] = weight_shape
|
||||
kvs["weight_pos"] = weight_pos
|
||||
kvs["weight_rot"] = weight_rot
|
||||
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
|
||||
save_exp_dir = osp.join(
|
||||
FLAGS.savedir, "{}_{}_joint".format(FLAGS.exp_size, FLAGS.exp_shape)
|
||||
)
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
|
||||
if FLAGS.task == 'conceptcombine':
|
||||
if FLAGS.task == "conceptcombine":
|
||||
conceptcombine(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'labeldiscover':
|
||||
elif FLAGS.task == "labeldiscover":
|
||||
labeldiscover(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'gentest':
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
|
||||
elif FLAGS.task == "gentest":
|
||||
save_exp_dir = osp.join(
|
||||
FLAGS.savedir, "{}_{}_gen".format(FLAGS.exp_size, FLAGS.exp_pos)
|
||||
)
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
gentest(sess, kvs, data, latents, save_exp_dir)
|
||||
elif FLAGS.task == 'genbaseline':
|
||||
save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
|
||||
elif FLAGS.task == "genbaseline":
|
||||
save_exp_dir = osp.join(
|
||||
FLAGS.savedir, "{}_{}_gen_baseline".format(FLAGS.exp_size, FLAGS.exp_pos)
|
||||
)
|
||||
if not osp.exists(save_exp_dir):
|
||||
os.makedirs(save_exp_dir)
|
||||
|
||||
if FLAGS.plot_curve:
|
||||
mse_losses = []
|
||||
for frac in [i/10 for i in range(11)]:
|
||||
mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
|
||||
for frac in [i / 10 for i in range(11)]:
|
||||
mse_loss = genbaseline(
|
||||
sess, kvs, data, latents, save_exp_dir, frac=frac
|
||||
)
|
||||
mse_losses.append(mse_loss)
|
||||
np.save("mse_baseline_comb.npy", mse_losses)
|
||||
else:
|
||||
genbaseline(sess, kvs, data, latents, save_exp_dir)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
File diff suppressed because it is too large
Load diff
214
EBMs/fid.py
214
EBMs/fid.py
|
@ -1,5 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
''' Calculates the Frechet Inception Distance (FID) to evalulate GANs.
|
||||
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs.
|
||||
|
||||
The FID metric calculates the distance between two distributions of images.
|
||||
Typically, we have summary statistics (mean & covariance matrix) of one
|
||||
|
@ -14,28 +14,33 @@ the pool_3 layer of the inception net for generated samples and real world
|
|||
samples respectivly.
|
||||
|
||||
See --help to see further details.
|
||||
'''
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
import gzip, pickle
|
||||
import tensorflow as tf
|
||||
from scipy.misc import imread
|
||||
from scipy import linalg
|
||||
import pathlib
|
||||
import urllib
|
||||
import tarfile
|
||||
import urllib
|
||||
import warnings
|
||||
|
||||
MODEL_DIR = '/tmp/imagenet'
|
||||
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from scipy import linalg
|
||||
from scipy.misc import imread
|
||||
|
||||
MODEL_DIR = "/tmp/imagenet"
|
||||
DATA_URL = (
|
||||
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
|
||||
)
|
||||
pool3 = None
|
||||
|
||||
|
||||
class InvalidFIDException(Exception):
|
||||
pass
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
def get_fid_score(images, images_gt):
|
||||
images = np.stack(images, 0)
|
||||
images_gt = np.stack(images_gt, 0)
|
||||
|
@ -52,34 +57,38 @@ def get_fid_score(images, images_gt):
|
|||
def create_inception_graph(pth):
|
||||
"""Creates a graph from saved GraphDef file."""
|
||||
# Creates graph from saved graph_def.pb.
|
||||
with tf.gfile.FastGFile( pth, 'rb') as f:
|
||||
with tf.gfile.FastGFile(pth, "rb") as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString( f.read())
|
||||
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
|
||||
#-------------------------------------------------------------------------------
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name="FID_Inception_Net")
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# code for handling inception net derived from
|
||||
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
|
||||
def _get_inception_layer(sess):
|
||||
"""Prepares inception net for batched usage and returns pool_3 layer. """
|
||||
layername = 'FID_Inception_Net/pool_3:0'
|
||||
"""Prepares inception net for batched usage and returns pool_3 layer."""
|
||||
layername = "FID_Inception_Net/pool_3:0"
|
||||
pool3 = sess.graph.get_tensor_by_name(layername)
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
||||
return pool3
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_activations(images, sess, batch_size=50, verbose=False):
|
||||
|
@ -100,23 +109,27 @@ def get_activations(images, sess, batch_size=50, verbose=False):
|
|||
# inception_layer = _get_inception_layer(sess)
|
||||
d0 = images.shape[0]
|
||||
if batch_size > d0:
|
||||
print("warning: batch size is bigger than the data size. setting batch size to data size")
|
||||
print(
|
||||
"warning: batch size is bigger than the data size. setting batch size to data size"
|
||||
)
|
||||
batch_size = d0
|
||||
n_batches = d0//batch_size
|
||||
n_used_imgs = n_batches*batch_size
|
||||
pred_arr = np.empty((n_used_imgs,2048))
|
||||
n_batches = d0 // batch_size
|
||||
n_used_imgs = n_batches * batch_size
|
||||
pred_arr = np.empty((n_used_imgs, 2048))
|
||||
for i in range(n_batches):
|
||||
if verbose:
|
||||
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
|
||||
start = i*batch_size
|
||||
print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
|
||||
start = i * batch_size
|
||||
end = start + batch_size
|
||||
batch = images[start:end]
|
||||
pred = sess.run(pool3, {'ExpandDims:0': batch})
|
||||
pred_arr[start:end] = pred.reshape(batch_size,-1)
|
||||
pred = sess.run(pool3, {"ExpandDims:0": batch})
|
||||
pred_arr[start:end] = pred.reshape(batch_size, -1)
|
||||
if verbose:
|
||||
print(" done")
|
||||
return pred_arr
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
||||
|
@ -147,15 +160,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
|||
sigma1 = np.atleast_2d(sigma1)
|
||||
sigma2 = np.atleast_2d(sigma2)
|
||||
|
||||
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
|
||||
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
|
||||
assert (
|
||||
mu1.shape == mu2.shape
|
||||
), "Training and test mean vectors have different lengths"
|
||||
assert (
|
||||
sigma1.shape == sigma2.shape
|
||||
), "Training and test covariances have different dimensions"
|
||||
|
||||
diff = mu1 - mu2
|
||||
|
||||
# product might be almost singular
|
||||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
||||
if not np.isfinite(covmean).all():
|
||||
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
|
||||
msg = (
|
||||
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
||||
% eps
|
||||
)
|
||||
warnings.warn(msg)
|
||||
offset = np.eye(sigma1.shape[0]) * eps
|
||||
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
||||
|
@ -170,7 +190,9 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
|||
tr_covmean = np.trace(covmean)
|
||||
|
||||
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
|
||||
|
@ -193,47 +215,52 @@ def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
|
|||
mu = np.mean(act, axis=0)
|
||||
sigma = np.cov(act, rowvar=False)
|
||||
return mu, sigma
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# The following functions aren't needed for calculating the FID
|
||||
# they're just here to make this module work as a stand-alone script
|
||||
# for calculating FID scores
|
||||
#-------------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------------
|
||||
def check_or_download_inception(inception_path):
|
||||
''' Checks if the path to the inception file is valid, or downloads
|
||||
the file if it is not present. '''
|
||||
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
"""Checks if the path to the inception file is valid, or downloads
|
||||
the file if it is not present."""
|
||||
INCEPTION_URL = (
|
||||
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
|
||||
)
|
||||
if inception_path is None:
|
||||
inception_path = '/tmp'
|
||||
inception_path = "/tmp"
|
||||
inception_path = pathlib.Path(inception_path)
|
||||
model_file = inception_path / 'classify_image_graph_def.pb'
|
||||
model_file = inception_path / "classify_image_graph_def.pb"
|
||||
if not model_file.exists():
|
||||
print("Downloading Inception model")
|
||||
from urllib import request
|
||||
import tarfile
|
||||
from urllib import request
|
||||
|
||||
fn, _ = request.urlretrieve(INCEPTION_URL)
|
||||
with tarfile.open(fn, mode='r') as f:
|
||||
f.extract('classify_image_graph_def.pb', str(model_file.parent))
|
||||
with tarfile.open(fn, mode="r") as f:
|
||||
f.extract("classify_image_graph_def.pb", str(model_file.parent))
|
||||
return str(model_file)
|
||||
|
||||
|
||||
def _handle_path(path, sess):
|
||||
if path.endswith('.npz'):
|
||||
if path.endswith(".npz"):
|
||||
f = np.load(path)
|
||||
m, s = f['mu'][:], f['sigma'][:]
|
||||
m, s = f["mu"][:], f["sigma"][:]
|
||||
f.close()
|
||||
else:
|
||||
path = pathlib.Path(path)
|
||||
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
|
||||
files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
|
||||
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
|
||||
m, s = calculate_activation_statistics(x, sess)
|
||||
return m, s
|
||||
|
||||
|
||||
def calculate_fid_given_paths(paths, inception_path):
|
||||
''' Calculates the FID of two paths. '''
|
||||
"""Calculates the FID of two paths."""
|
||||
inception_path = check_or_download_inception(inception_path)
|
||||
|
||||
for p in paths:
|
||||
|
@ -250,43 +277,48 @@ def calculate_fid_given_paths(paths, inception_path):
|
|||
|
||||
|
||||
def _init_inception():
|
||||
global pool3
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
||||
filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(os.path.join(
|
||||
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name='')
|
||||
# Works with an arbitrary minibatch size.
|
||||
with tf.Session() as sess:
|
||||
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
||||
global pool3
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split("/")[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write(
|
||||
"\r>> Downloading %s %.1f%%"
|
||||
% (filename, float(count * block_size) / float(total_size) * 100.0)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
|
||||
tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(
|
||||
os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
|
||||
) as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name="")
|
||||
# Works with an arbitrary minibatch size.
|
||||
with tf.Session() as sess:
|
||||
pool3 = sess.graph.get_tensor_by_name("pool_3:0")
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims != []:
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
||||
|
||||
|
||||
if pool3 is None:
|
||||
_init_inception()
|
||||
_init_inception()
|
||||
|
|
42
EBMs/hmc.py
42
EBMs/hmc.py
|
@ -1,11 +1,11 @@
|
|||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.platform import flags
|
||||
flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance raes')
|
||||
|
||||
flags.DEFINE_bool("proposal_debug", False, "Print hmc acceptance raes")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def kinetic_energy(velocity):
|
||||
"""Kinetic energy of the current velocity (assuming a standard Gaussian)
|
||||
(x dot x) / 2
|
||||
|
@ -21,6 +21,7 @@ def kinetic_energy(velocity):
|
|||
"""
|
||||
return 0.5 * tf.square(velocity)
|
||||
|
||||
|
||||
def hamiltonian(position, velocity, energy_function):
|
||||
"""Computes the Hamiltonian of the current position, velocity pair
|
||||
|
||||
|
@ -44,13 +45,12 @@ def hamiltonian(position, velocity, energy_function):
|
|||
"""
|
||||
batch_size = tf.shape(velocity)[0]
|
||||
kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
|
||||
return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1])
|
||||
return tf.squeeze(energy_function(position)) + tf.reduce_sum(
|
||||
kinetic_energy_flat, axis=[1]
|
||||
)
|
||||
|
||||
def leapfrog_step(x0,
|
||||
v0,
|
||||
neg_log_posterior,
|
||||
step_size,
|
||||
num_steps):
|
||||
|
||||
def leapfrog_step(x0, v0, neg_log_posterior, step_size, num_steps):
|
||||
|
||||
# Start by updating the velocity a half-step
|
||||
v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
|
||||
|
@ -83,10 +83,8 @@ def leapfrog_step(x0,
|
|||
# return new proposal state
|
||||
return x, v
|
||||
|
||||
def hmc(initial_x,
|
||||
step_size,
|
||||
num_steps,
|
||||
neg_log_posterior):
|
||||
|
||||
def hmc(initial_x, step_size, num_steps, neg_log_posterior):
|
||||
"""Summary
|
||||
|
||||
Parameters
|
||||
|
@ -107,11 +105,13 @@ def hmc(initial_x,
|
|||
"""
|
||||
|
||||
v0 = tf.random_normal(tf.shape(initial_x))
|
||||
x, v = leapfrog_step(initial_x,
|
||||
v0,
|
||||
step_size=step_size,
|
||||
num_steps=num_steps,
|
||||
neg_log_posterior=neg_log_posterior)
|
||||
x, v = leapfrog_step(
|
||||
initial_x,
|
||||
v0,
|
||||
step_size=step_size,
|
||||
num_steps=num_steps,
|
||||
neg_log_posterior=neg_log_posterior,
|
||||
)
|
||||
|
||||
orig = hamiltonian(initial_x, v0, neg_log_posterior)
|
||||
current = hamiltonian(x, v, neg_log_posterior)
|
||||
|
@ -119,10 +119,12 @@ def hmc(initial_x,
|
|||
prob_accept = tf.exp(orig - current)
|
||||
|
||||
if FLAGS.proposal_debug:
|
||||
prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))])
|
||||
prob_accept = tf.Print(
|
||||
prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]
|
||||
)
|
||||
|
||||
uniform = tf.random_uniform(tf.shape(prob_accept))
|
||||
keep_mask = (prob_accept > uniform)
|
||||
keep_mask = prob_accept > uniform
|
||||
# print(keep_mask.get_shape())
|
||||
|
||||
x_new = tf.where(keep_mask, x, initial_x)
|
||||
|
|
|
@ -1,24 +1,28 @@
|
|||
from models import ResNet128
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
from tensorflow.python.platform import flags
|
||||
import tensorflow as tf
|
||||
|
||||
import imageio
|
||||
from utils import optimistic_restore
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from models import ResNet128
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
|
||||
flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics')
|
||||
flags.DEFINE_integer('batch_size', 16, 'number of steps to run')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
|
||||
flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model')
|
||||
flags.DEFINE_bool('cclass', True, 'conditional models')
|
||||
flags.DEFINE_bool('use_attention', False, 'using attention')
|
||||
flags.DEFINE_string(
|
||||
"logdir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_integer("num_steps", 200, "num of steps for conditional imagenet sampling")
|
||||
flags.DEFINE_float("step_lr", 180.0, "step size for Langevin dynamics")
|
||||
flags.DEFINE_integer("batch_size", 16, "number of steps to run")
|
||||
flags.DEFINE_string("exp", "default", "name of experiments")
|
||||
flags.DEFINE_integer("resume_iter", -1, "iteration to resume training from")
|
||||
flags.DEFINE_bool(
|
||||
"spec_norm", True, "whether to use spectral normalization in weights in a model"
|
||||
)
|
||||
flags.DEFINE_bool("cclass", True, "conditional models")
|
||||
flags.DEFINE_bool("use_attention", False, "using attention")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def rescale_im(im):
|
||||
return np.clip(im * 256, 0, 255).astype(np.uint8)
|
||||
|
||||
|
@ -32,12 +36,11 @@ if __name__ == "__main__":
|
|||
weights = model.construct_weights("context_0")
|
||||
|
||||
x_mod = X_NOISE
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
|
||||
mean=0.0,
|
||||
stddev=0.005)
|
||||
x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
|
||||
|
||||
energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL,
|
||||
reuse=True, stop_at_grad=False, stop_batch=True)
|
||||
energy_noise = energy_start = model.forward(
|
||||
x_mod, weights, label=LABEL, reuse=True, stop_at_grad=False, stop_batch=True
|
||||
)
|
||||
|
||||
x_grad = tf.gradients(energy_noise, [x_mod])[0]
|
||||
energy_noise_old = energy_noise
|
||||
|
@ -54,20 +57,23 @@ if __name__ == "__main__":
|
|||
saver = loader = tf.train.Saver()
|
||||
|
||||
logdir = osp.join(FLAGS.logdir, FLAGS.exp)
|
||||
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
|
||||
model_file = osp.join(logdir, "model_{}".format(FLAGS.resume_iter))
|
||||
saver.restore(sess, model_file)
|
||||
|
||||
lx = np.random.permutation(1000)[:16]
|
||||
ims = []
|
||||
|
||||
# What to initialize sampling with.
|
||||
# What to initialize sampling with.
|
||||
x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
|
||||
labels = np.eye(1000)[lx]
|
||||
|
||||
for i in range(FLAGS.num_steps):
|
||||
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels})
|
||||
ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3)))
|
||||
|
||||
imageio.mimwrite('sample.gif', ims)
|
||||
|
||||
e, x_mod = sess.run([energy_noise, x_output], {X_NOISE: x_mod, LABEL: labels})
|
||||
ims.append(
|
||||
rescale_im(x_mod)
|
||||
.reshape((4, 4, 128, 128, 3))
|
||||
.transpose((0, 2, 1, 3, 4))
|
||||
.reshape((512, 512, 3))
|
||||
)
|
||||
|
||||
imageio.mimwrite("sample.gif", ims)
|
||||
|
|
|
@ -13,14 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Image pre-processing utilities.
|
||||
"""
|
||||
"""Image pre-processing utilities."""
|
||||
import tensorflow as tf
|
||||
|
||||
IMAGE_DEPTH = 3 # color images
|
||||
|
||||
IMAGE_DEPTH = 3 # color images
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# _R_MEAN = 123.68
|
||||
# _G_MEAN = 116.78
|
||||
|
@ -35,303 +32,318 @@ _RESIZE_MIN = 128
|
|||
|
||||
|
||||
def _decode_crop_and_flip(image_buffer, bbox, num_channels):
|
||||
"""Crops the given image to a random part of the image, and randomly flips.
|
||||
"""Crops the given image to a random part of the image, and randomly flips.
|
||||
|
||||
We use the fused decode_and_crop op, which performs better than the two ops
|
||||
used separately in series, but note that this requires that the image be
|
||||
passed in as an un-decoded string Tensor.
|
||||
We use the fused decode_and_crop op, which performs better than the two ops
|
||||
used separately in series, but note that this requires that the image be
|
||||
passed in as an un-decoded string Tensor.
|
||||
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
|
||||
"""
|
||||
# A large fraction of image datasets contain a human-annotated bounding box
|
||||
# delineating the region of the image containing the object of interest. We
|
||||
# choose to create a new bounding box for the object which is a randomly
|
||||
# distorted version of the human-annotated bounding box that obeys an
|
||||
# allowed range of aspect ratios, sizes and overlap with the human-annotated
|
||||
# bounding box. If no box is supplied, then we assume the bounding box is
|
||||
# the entire image.
|
||||
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||
tf.image.extract_jpeg_shape(image_buffer),
|
||||
bounding_boxes=bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=[0.75, 1.33],
|
||||
area_range=[0.05, 1.0],
|
||||
max_attempts=100,
|
||||
use_image_if_no_bounding_boxes=True)
|
||||
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||
"""
|
||||
# A large fraction of image datasets contain a human-annotated bounding box
|
||||
# delineating the region of the image containing the object of interest. We
|
||||
# choose to create a new bounding box for the object which is a randomly
|
||||
# distorted version of the human-annotated bounding box that obeys an
|
||||
# allowed range of aspect ratios, sizes and overlap with the human-annotated
|
||||
# bounding box. If no box is supplied, then we assume the bounding box is
|
||||
# the entire image.
|
||||
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||
tf.image.extract_jpeg_shape(image_buffer),
|
||||
bounding_boxes=bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=[0.75, 1.33],
|
||||
area_range=[0.05, 1.0],
|
||||
max_attempts=100,
|
||||
use_image_if_no_bounding_boxes=True,
|
||||
)
|
||||
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||
|
||||
# Reassemble the bounding box in the format the crop op requires.
|
||||
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
|
||||
# Reassemble the bounding box in the format the crop op requires.
|
||||
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
|
||||
|
||||
# Use the fused decode and crop op here, which is faster than each in series.
|
||||
cropped = tf.image.decode_and_crop_jpeg(
|
||||
image_buffer, crop_window, channels=num_channels)
|
||||
# Use the fused decode and crop op here, which is faster than each in
|
||||
# series.
|
||||
cropped = tf.image.decode_and_crop_jpeg(
|
||||
image_buffer, crop_window, channels=num_channels
|
||||
)
|
||||
|
||||
# Flip to add a little more random distortion in.
|
||||
cropped = tf.image.random_flip_left_right(cropped)
|
||||
return cropped
|
||||
# Flip to add a little more random distortion in.
|
||||
cropped = tf.image.random_flip_left_right(cropped)
|
||||
return cropped
|
||||
|
||||
|
||||
def _central_crop(image, crop_height, crop_width):
|
||||
"""Performs central crops of the given image list.
|
||||
"""Performs central crops of the given image list.
|
||||
|
||||
Args:
|
||||
image: a 3-D image tensor
|
||||
crop_height: the height of the image following the crop.
|
||||
crop_width: the width of the image following the crop.
|
||||
Args:
|
||||
image: a 3-D image tensor
|
||||
crop_height: the height of the image following the crop.
|
||||
crop_width: the width of the image following the crop.
|
||||
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
Returns:
|
||||
3-D tensor with cropped image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
|
||||
amount_to_be_cropped_h = (height - crop_height)
|
||||
crop_top = amount_to_be_cropped_h // 2
|
||||
amount_to_be_cropped_w = (width - crop_width)
|
||||
crop_left = amount_to_be_cropped_w // 2
|
||||
return tf.slice(
|
||||
image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
|
||||
amount_to_be_cropped_h = height - crop_height
|
||||
crop_top = amount_to_be_cropped_h // 2
|
||||
amount_to_be_cropped_w = width - crop_width
|
||||
crop_left = amount_to_be_cropped_w // 2
|
||||
return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
|
||||
|
||||
|
||||
def _mean_image_subtraction(image, means, num_channels):
|
||||
"""Subtracts the given means from each image channel.
|
||||
"""Subtracts the given means from each image channel.
|
||||
|
||||
For example:
|
||||
means = [123.68, 116.779, 103.939]
|
||||
image = _mean_image_subtraction(image, means)
|
||||
For example:
|
||||
means = [123.68, 116.779, 103.939]
|
||||
image = _mean_image_subtraction(image, means)
|
||||
|
||||
Note that the rank of `image` must be known.
|
||||
Note that the rank of `image` must be known.
|
||||
|
||||
Args:
|
||||
image: a tensor of size [height, width, C].
|
||||
means: a C-vector of values to subtract from each channel.
|
||||
num_channels: number of color channels in the image that will be distorted.
|
||||
Args:
|
||||
image: a tensor of size [height, width, C].
|
||||
means: a C-vector of values to subtract from each channel.
|
||||
num_channels: number of color channels in the image that will be distorted.
|
||||
|
||||
Returns:
|
||||
the centered image.
|
||||
Returns:
|
||||
the centered image.
|
||||
|
||||
Raises:
|
||||
ValueError: If the rank of `image` is unknown, if `image` has a rank other
|
||||
than three or if the number of channels in `image` doesn't match the
|
||||
number of values in `means`.
|
||||
"""
|
||||
if image.get_shape().ndims != 3:
|
||||
raise ValueError('Input must be of size [height, width, C>0]')
|
||||
Raises:
|
||||
ValueError: If the rank of `image` is unknown, if `image` has a rank other
|
||||
than three or if the number of channels in `image` doesn't match the
|
||||
number of values in `means`.
|
||||
"""
|
||||
if image.get_shape().ndims != 3:
|
||||
raise ValueError("Input must be of size [height, width, C>0]")
|
||||
|
||||
if len(means) != num_channels:
|
||||
raise ValueError('len(means) must match the number of channels')
|
||||
if len(means) != num_channels:
|
||||
raise ValueError("len(means) must match the number of channels")
|
||||
|
||||
# We have a 1-D tensor of means; convert to 3-D.
|
||||
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
|
||||
# We have a 1-D tensor of means; convert to 3-D.
|
||||
means = tf.expand_dims(tf.expand_dims(means, 0), 0)
|
||||
|
||||
return image - means
|
||||
return image - means
|
||||
|
||||
|
||||
def _smallest_size_at_least(height, width, resize_min):
|
||||
"""Computes new shape with the smallest side equal to `smallest_side`.
|
||||
"""Computes new shape with the smallest side equal to `smallest_side`.
|
||||
|
||||
Computes new shape with the smallest side equal to `smallest_side` while
|
||||
preserving the original aspect ratio.
|
||||
Computes new shape with the smallest side equal to `smallest_side` while
|
||||
preserving the original aspect ratio.
|
||||
|
||||
Args:
|
||||
height: an int32 scalar tensor indicating the current height.
|
||||
width: an int32 scalar tensor indicating the current width.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
Args:
|
||||
height: an int32 scalar tensor indicating the current height.
|
||||
width: an int32 scalar tensor indicating the current width.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
|
||||
Returns:
|
||||
new_height: an int32 scalar tensor indicating the new height.
|
||||
new_width: an int32 scalar tensor indicating the new width.
|
||||
"""
|
||||
resize_min = tf.cast(resize_min, tf.float32)
|
||||
Returns:
|
||||
new_height: an int32 scalar tensor indicating the new height.
|
||||
new_width: an int32 scalar tensor indicating the new width.
|
||||
"""
|
||||
resize_min = tf.cast(resize_min, tf.float32)
|
||||
|
||||
# Convert to floats to make subsequent calculations go smoothly.
|
||||
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
|
||||
# Convert to floats to make subsequent calculations go smoothly.
|
||||
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
|
||||
|
||||
smaller_dim = tf.minimum(height, width)
|
||||
scale_ratio = resize_min / smaller_dim
|
||||
smaller_dim = tf.minimum(height, width)
|
||||
scale_ratio = resize_min / smaller_dim
|
||||
|
||||
# Convert back to ints to make heights and widths that TF ops will accept.
|
||||
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
|
||||
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
|
||||
# Convert back to ints to make heights and widths that TF ops will accept.
|
||||
new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
|
||||
new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
|
||||
|
||||
return new_height, new_width
|
||||
return new_height, new_width
|
||||
|
||||
|
||||
def _aspect_preserving_resize(image, resize_min):
|
||||
"""Resize images preserving the original aspect ratio.
|
||||
"""Resize images preserving the original aspect ratio.
|
||||
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
resize_min: A python integer or scalar `Tensor` indicating the size of
|
||||
the smallest side after resize.
|
||||
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image.
|
||||
"""
|
||||
shape = tf.shape(input=image)
|
||||
height, width = shape[0], shape[1]
|
||||
|
||||
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
|
||||
new_height, new_width = _smallest_size_at_least(height, width, resize_min)
|
||||
|
||||
return _resize_image(image, new_height, new_width)
|
||||
return _resize_image(image, new_height, new_width)
|
||||
|
||||
|
||||
def _resize_image(image, height, width):
|
||||
"""Simple wrapper around tf.resize_images.
|
||||
"""Simple wrapper around tf.resize_images.
|
||||
|
||||
This is primarily to make sure we use the same `ResizeMethod` and other
|
||||
details each time.
|
||||
This is primarily to make sure we use the same `ResizeMethod` and other
|
||||
details each time.
|
||||
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
height: The target height for the resized image.
|
||||
width: The target width for the resized image.
|
||||
Args:
|
||||
image: A 3-D image `Tensor`.
|
||||
height: The target height for the resized image.
|
||||
width: The target width for the resized image.
|
||||
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image. The first two
|
||||
dimensions have the shape [height, width].
|
||||
"""
|
||||
return tf.image.resize_images(
|
||||
image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
|
||||
align_corners=False)
|
||||
Returns:
|
||||
resized_image: A 3-D tensor containing the resized image. The first two
|
||||
dimensions have the shape [height, width].
|
||||
"""
|
||||
return tf.image.resize_images(
|
||||
image,
|
||||
[height, width],
|
||||
method=tf.image.ResizeMethod.BILINEAR,
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
|
||||
def preprocess_image(image_buffer, bbox, output_height, output_width,
|
||||
num_channels, is_training=False):
|
||||
"""Preprocesses the given image.
|
||||
def preprocess_image(
|
||||
image_buffer, bbox, output_height, output_width, num_channels, is_training=False
|
||||
):
|
||||
"""Preprocesses the given image.
|
||||
|
||||
Preprocessing includes decoding, cropping, and resizing for both training
|
||||
and eval images. Training preprocessing, however, introduces some random
|
||||
distortion of the image to improve accuracy.
|
||||
Preprocessing includes decoding, cropping, and resizing for both training
|
||||
and eval images. Training preprocessing, however, introduces some random
|
||||
distortion of the image to improve accuracy.
|
||||
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
output_height: The height of the image after preprocessing.
|
||||
output_width: The width of the image after preprocessing.
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
is_training: `True` if we're preprocessing the image for training and
|
||||
`False` otherwise.
|
||||
Args:
|
||||
image_buffer: scalar string Tensor representing the raw JPEG image buffer.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
output_height: The height of the image after preprocessing.
|
||||
output_width: The width of the image after preprocessing.
|
||||
num_channels: Integer depth of the image buffer for decoding.
|
||||
is_training: `True` if we're preprocessing the image for training and
|
||||
`False` otherwise.
|
||||
|
||||
Returns:
|
||||
A preprocessed image.
|
||||
"""
|
||||
if is_training:
|
||||
# For training, we want to randomize some of the distortions.
|
||||
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
|
||||
image = _resize_image(image, output_height, output_width)
|
||||
else:
|
||||
# For validation, we want to decode, resize, then just crop the middle.
|
||||
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
|
||||
image = _aspect_preserving_resize(image, _RESIZE_MIN)
|
||||
print(image)
|
||||
image = _central_crop(image, output_height, output_width)
|
||||
Returns:
|
||||
A preprocessed image.
|
||||
"""
|
||||
if is_training:
|
||||
# For training, we want to randomize some of the distortions.
|
||||
image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
|
||||
image = _resize_image(image, output_height, output_width)
|
||||
else:
|
||||
# For validation, we want to decode, resize, then just crop the middle.
|
||||
image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
|
||||
image = _aspect_preserving_resize(image, _RESIZE_MIN)
|
||||
print(image)
|
||||
image = _central_crop(image, output_height, output_width)
|
||||
|
||||
image.set_shape([output_height, output_width, num_channels])
|
||||
image.set_shape([output_height, output_width, num_channels])
|
||||
|
||||
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
|
||||
return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
|
||||
|
||||
|
||||
def parse_example_proto(example_serialized):
|
||||
"""Parses an Example proto containing a training example of an image.
|
||||
"""Parses an Example proto containing a training example of an image.
|
||||
|
||||
The output of the build_image_data.py image preprocessing script is a dataset
|
||||
containing serialized Example protocol buffers. Each Example proto contains
|
||||
the following fields:
|
||||
The output of the build_image_data.py image preprocessing script is a dataset
|
||||
containing serialized Example protocol buffers. Each Example proto contains
|
||||
the following fields:
|
||||
|
||||
image/height: 462
|
||||
image/width: 581
|
||||
image/colorspace: 'RGB'
|
||||
image/channels: 3
|
||||
image/class/label: 615
|
||||
image/class/synset: 'n03623198'
|
||||
image/class/text: 'knee pad'
|
||||
image/object/bbox/xmin: 0.1
|
||||
image/object/bbox/xmax: 0.9
|
||||
image/object/bbox/ymin: 0.2
|
||||
image/object/bbox/ymax: 0.6
|
||||
image/object/bbox/label: 615
|
||||
image/format: 'JPEG'
|
||||
image/filename: 'ILSVRC2012_val_00041207.JPEG'
|
||||
image/encoded: <JPEG encoded string>
|
||||
image/height: 462
|
||||
image/width: 581
|
||||
image/colorspace: 'RGB'
|
||||
image/channels: 3
|
||||
image/class/label: 615
|
||||
image/class/synset: 'n03623198'
|
||||
image/class/text: 'knee pad'
|
||||
image/object/bbox/xmin: 0.1
|
||||
image/object/bbox/xmax: 0.9
|
||||
image/object/bbox/ymin: 0.2
|
||||
image/object/bbox/ymax: 0.6
|
||||
image/object/bbox/label: 615
|
||||
image/format: 'JPEG'
|
||||
image/filename: 'ILSVRC2012_val_00041207.JPEG'
|
||||
image/encoded: <JPEG encoded string>
|
||||
|
||||
Args:
|
||||
example_serialized: scalar Tensor tf.string containing a serialized
|
||||
Example protocol buffer.
|
||||
Args:
|
||||
example_serialized: scalar Tensor tf.string containing a serialized
|
||||
Example protocol buffer.
|
||||
|
||||
Returns:
|
||||
image_buffer: Tensor tf.string containing the contents of a JPEG file.
|
||||
label: Tensor tf.int32 containing the label.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
text: Tensor tf.string containing the human-readable label.
|
||||
"""
|
||||
# Dense features in Example proto.
|
||||
feature_map = {
|
||||
'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
|
||||
default_value=''),
|
||||
'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
|
||||
default_value=-1),
|
||||
'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
|
||||
default_value=''),
|
||||
}
|
||||
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
|
||||
# Sparse features in Example proto.
|
||||
feature_map.update(
|
||||
{k: sparse_float32 for k in ['image/object/bbox/xmin',
|
||||
'image/object/bbox/ymin',
|
||||
'image/object/bbox/xmax',
|
||||
'image/object/bbox/ymax']})
|
||||
Returns:
|
||||
image_buffer: Tensor tf.string containing the contents of a JPEG file.
|
||||
label: Tensor tf.int32 containing the label.
|
||||
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
|
||||
where each coordinate is [0, 1) and the coordinates are arranged as
|
||||
[ymin, xmin, ymax, xmax].
|
||||
text: Tensor tf.string containing the human-readable label.
|
||||
"""
|
||||
# Dense features in Example proto.
|
||||
feature_map = {
|
||||
"image/encoded": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
|
||||
"image/class/label": tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1),
|
||||
"image/class/text": tf.FixedLenFeature([], dtype=tf.string, default_value=""),
|
||||
}
|
||||
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
|
||||
# Sparse features in Example proto.
|
||||
feature_map.update(
|
||||
{
|
||||
k: sparse_float32
|
||||
for k in [
|
||||
"image/object/bbox/xmin",
|
||||
"image/object/bbox/ymin",
|
||||
"image/object/bbox/xmax",
|
||||
"image/object/bbox/ymax",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
features = tf.parse_single_example(example_serialized, feature_map)
|
||||
label = tf.cast(features['image/class/label'], dtype=tf.int32)
|
||||
features = tf.parse_single_example(example_serialized, feature_map)
|
||||
label = tf.cast(features["image/class/label"], dtype=tf.int32)
|
||||
|
||||
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
|
||||
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
|
||||
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
|
||||
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
|
||||
xmin = tf.expand_dims(features["image/object/bbox/xmin"].values, 0)
|
||||
ymin = tf.expand_dims(features["image/object/bbox/ymin"].values, 0)
|
||||
xmax = tf.expand_dims(features["image/object/bbox/xmax"].values, 0)
|
||||
ymax = tf.expand_dims(features["image/object/bbox/ymax"].values, 0)
|
||||
|
||||
# Note that we impose an ordering of (y, x) just to make life difficult.
|
||||
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
|
||||
# Note that we impose an ordering of (y, x) just to make life difficult.
|
||||
bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
|
||||
|
||||
# Force the variable number of bounding boxes into the shape
|
||||
# [1, num_boxes, coords].
|
||||
bbox = tf.expand_dims(bbox, 0)
|
||||
bbox = tf.transpose(bbox, [0, 2, 1])
|
||||
# Force the variable number of bounding boxes into the shape
|
||||
# [1, num_boxes, coords].
|
||||
bbox = tf.expand_dims(bbox, 0)
|
||||
bbox = tf.transpose(bbox, [0, 2, 1])
|
||||
|
||||
return features['image/encoded'], label, bbox, features['image/class/text']
|
||||
return features["image/encoded"], label, bbox, features["image/class/text"]
|
||||
|
||||
|
||||
class ImagenetPreprocessor:
|
||||
def __init__(self, image_size, dtype, train):
|
||||
self.image_size = image_size
|
||||
self.dtype = dtype
|
||||
self.train = train
|
||||
def __init__(self, image_size, dtype, train):
|
||||
self.image_size = image_size
|
||||
self.dtype = dtype
|
||||
self.train = train
|
||||
|
||||
def preprocess(self, image_buffer, bbox):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train)
|
||||
return tf.cast(image, self.dtype)
|
||||
|
||||
def parse_and_preprocess(self, value):
|
||||
image_buffer, label_index, bbox, _ = parse_example_proto(value)
|
||||
image = self.preprocess(image_buffer, bbox)
|
||||
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
|
||||
return label_index, image
|
||||
def preprocess(self, image_buffer, bbox):
|
||||
# pylint: disable=g-import-not-at-top
|
||||
image = preprocess_image(
|
||||
image_buffer,
|
||||
bbox,
|
||||
self.image_size,
|
||||
self.image_size,
|
||||
IMAGE_DEPTH,
|
||||
is_training=self.train,
|
||||
)
|
||||
return tf.cast(image, self.dtype)
|
||||
|
||||
def parse_and_preprocess(self, value):
|
||||
image_buffer, label_index, bbox, _ = parse_example_proto(value)
|
||||
image = self.preprocess(image_buffer, bbox)
|
||||
image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
|
||||
return label_index, image
|
||||
|
|
|
@ -1,105 +1,112 @@
|
|||
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# Code derived from
|
||||
# tensorflow/tensorflow/models/image/imagenet/classify_image.py
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import math
|
||||
import os.path
|
||||
import sys
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
from six.moves import urllib
|
||||
import tensorflow as tf
|
||||
import glob
|
||||
import scipy.misc
|
||||
import math
|
||||
import sys
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from six.moves import urllib
|
||||
|
||||
MODEL_DIR = '/tmp/imagenet'
|
||||
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
MODEL_DIR = "/tmp/imagenet"
|
||||
DATA_URL = (
|
||||
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
|
||||
)
|
||||
softmax = None
|
||||
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.visible_device_list = str(hvd.local_rank())
|
||||
sess = tf.Session(config=config)
|
||||
|
||||
# Call this function with list of images. Each of elements should be a
|
||||
|
||||
# Call this function with list of images. Each of elements should be a
|
||||
# numpy array with values ranging from 0 to 255.
|
||||
def get_inception_score(images, splits=10):
|
||||
# For convenience
|
||||
if len(images[0].shape) != 3:
|
||||
return 0, 0
|
||||
# For convenience
|
||||
if len(images[0].shape) != 3:
|
||||
return 0, 0
|
||||
|
||||
# Bypassing all the assertions so that we don't end prematuraly'
|
||||
# assert(type(images) == list)
|
||||
# assert(type(images[0]) == np.ndarray)
|
||||
# assert(len(images[0].shape) == 3)
|
||||
# assert(np.max(images[0]) > 10)
|
||||
# assert(np.min(images[0]) >= 0.0)
|
||||
inps = []
|
||||
for img in images:
|
||||
img = img.astype(np.float32)
|
||||
inps.append(np.expand_dims(img, 0))
|
||||
bs = 1
|
||||
preds = []
|
||||
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
|
||||
for i in range(n_batches):
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
inp = inps[(i * bs) : min((i + 1) * bs, len(inps))]
|
||||
inp = np.concatenate(inp, 0)
|
||||
pred = sess.run(softmax, {"ExpandDims:0": inp})
|
||||
preds.append(pred)
|
||||
preds = np.concatenate(preds, 0)
|
||||
scores = []
|
||||
for i in range(splits):
|
||||
part = preds[
|
||||
(i * preds.shape[0] // splits) : ((i + 1) * preds.shape[0] // splits), :
|
||||
]
|
||||
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
||||
kl = np.mean(np.sum(kl, 1))
|
||||
scores.append(np.exp(kl))
|
||||
return np.mean(scores), np.std(scores)
|
||||
|
||||
# Bypassing all the assertions so that we don't end prematuraly'
|
||||
# assert(type(images) == list)
|
||||
# assert(type(images[0]) == np.ndarray)
|
||||
# assert(len(images[0].shape) == 3)
|
||||
# assert(np.max(images[0]) > 10)
|
||||
# assert(np.min(images[0]) >= 0.0)
|
||||
inps = []
|
||||
for img in images:
|
||||
img = img.astype(np.float32)
|
||||
inps.append(np.expand_dims(img, 0))
|
||||
bs = 1
|
||||
preds = []
|
||||
n_batches = int(math.ceil(float(len(inps)) / float(bs)))
|
||||
for i in range(n_batches):
|
||||
sys.stdout.write(".")
|
||||
sys.stdout.flush()
|
||||
inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
|
||||
inp = np.concatenate(inp, 0)
|
||||
pred = sess.run(softmax, {'ExpandDims:0': inp})
|
||||
preds.append(pred)
|
||||
preds = np.concatenate(preds, 0)
|
||||
scores = []
|
||||
for i in range(splits):
|
||||
part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
|
||||
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
||||
kl = np.mean(np.sum(kl, 1))
|
||||
scores.append(np.exp(kl))
|
||||
return np.mean(scores), np.std(scores)
|
||||
|
||||
# This function is called automatically.
|
||||
def _init_inception():
|
||||
global softmax
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
||||
filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(os.path.join(
|
||||
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name='')
|
||||
# Works with an arbitrary minibatch size.
|
||||
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.set_shape(tf.TensorShape(new_shape))
|
||||
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
|
||||
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
|
||||
softmax = tf.nn.softmax(logits)
|
||||
global softmax
|
||||
if not os.path.exists(MODEL_DIR):
|
||||
os.makedirs(MODEL_DIR)
|
||||
filename = DATA_URL.split("/")[-1]
|
||||
filepath = os.path.join(MODEL_DIR, filename)
|
||||
if not os.path.exists(filepath):
|
||||
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write(
|
||||
"\r>> Downloading %s %.1f%%"
|
||||
% (filename, float(count * block_size) / float(total_size) * 100.0)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
|
||||
tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
|
||||
with tf.gfile.FastGFile(
|
||||
os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
|
||||
) as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name="")
|
||||
# Works with an arbitrary minibatch size.
|
||||
pool3 = sess.graph.get_tensor_by_name("pool_3:0")
|
||||
ops = pool3.graph.get_operations()
|
||||
for op_idx, op in enumerate(ops):
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
shape = [s.value for s in shape]
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.set_shape(tf.TensorShape(new_shape))
|
||||
w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
|
||||
logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
|
||||
softmax = tf.nn.softmax(logits)
|
||||
|
||||
|
||||
if softmax is None:
|
||||
_init_inception()
|
||||
_init_inception()
|
||||
|
|
1327
EBMs/models.py
1327
EBMs/models.py
File diff suppressed because it is too large
Load diff
|
@ -1,53 +1,56 @@
|
|||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from tensorflow.python.platform import flags
|
||||
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
|
||||
import os.path as osp
|
||||
import os
|
||||
from utils import optimistic_restore, remap_restore, optimistic_remap_restore
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
from scipy.misc import imsave
|
||||
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from baselines.common.tf_util import initialize
|
||||
|
||||
import horovod.tensorflow as hvd
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from baselines.common.tf_util import initialize
|
||||
from data import Cifar10, Imagenet, TFImagenetLoader
|
||||
from fid import get_fid_score
|
||||
from inception import get_inception_score
|
||||
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
|
||||
from scipy.misc import imsave
|
||||
from tensorflow.python.platform import flags
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from utils import optimistic_remap_restore
|
||||
|
||||
hvd.init()
|
||||
|
||||
from inception import get_inception_score
|
||||
from fid import get_fid_score
|
||||
|
||||
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
|
||||
flags.DEFINE_string('exp', 'default', 'name of experiments')
|
||||
flags.DEFINE_bool('cclass', False, 'whether to condition on class')
|
||||
flags.DEFINE_string(
|
||||
"logdir", "cachedir", "location where log of experiments will be stored"
|
||||
)
|
||||
flags.DEFINE_string("exp", "default", "name of experiments")
|
||||
flags.DEFINE_bool("cclass", False, "whether to condition on class")
|
||||
|
||||
# Architecture settings
|
||||
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
|
||||
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
|
||||
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
|
||||
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
|
||||
flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent')
|
||||
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
|
||||
flags.DEFINE_float('proj_norm', 0.05, 'Maximum change of input images')
|
||||
flags.DEFINE_integer('batch_size', 512, 'batch size')
|
||||
flags.DEFINE_integer('resume_iter', -1, 'resume iteration')
|
||||
flags.DEFINE_integer('ensemble', 10, 'number of ensembles')
|
||||
flags.DEFINE_integer('im_number', 50000, 'number of ensembles')
|
||||
flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations')
|
||||
flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output')
|
||||
flags.DEFINE_integer('idx', 0, 'save index')
|
||||
flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing')
|
||||
flags.DEFINE_bool('scaled', True, 'whether to scale noise added')
|
||||
flags.DEFINE_bool('large_model', False, 'whether to use a small or large model')
|
||||
flags.DEFINE_bool('larger_model', False, 'Whether to use a large model')
|
||||
flags.DEFINE_bool('wider_model', False, 'Whether to use a large model')
|
||||
flags.DEFINE_bool('single', False, 'single ')
|
||||
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
|
||||
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull')
|
||||
flags.DEFINE_bool("bn", False, "Whether to use batch normalization or not")
|
||||
flags.DEFINE_bool("spec_norm", True, "Whether to use spectral normalization on weights")
|
||||
flags.DEFINE_bool("use_bias", True, "Whether to use bias in convolution")
|
||||
flags.DEFINE_bool("use_attention", False, "Whether to use self attention in network")
|
||||
flags.DEFINE_float("step_lr", 10.0, "Size of steps for gradient descent")
|
||||
flags.DEFINE_integer("num_steps", 20, "number of steps to optimize the label")
|
||||
flags.DEFINE_float("proj_norm", 0.05, "Maximum change of input images")
|
||||
flags.DEFINE_integer("batch_size", 512, "batch size")
|
||||
flags.DEFINE_integer("resume_iter", -1, "resume iteration")
|
||||
flags.DEFINE_integer("ensemble", 10, "number of ensembles")
|
||||
flags.DEFINE_integer("im_number", 50000, "number of ensembles")
|
||||
flags.DEFINE_integer("repeat_scale", 100, "number of repeat iterations")
|
||||
flags.DEFINE_float("noise_scale", 0.005, "amount of noise to output")
|
||||
flags.DEFINE_integer("idx", 0, "save index")
|
||||
flags.DEFINE_integer("nomix", 10, "number of intervals to stop mixing")
|
||||
flags.DEFINE_bool("scaled", True, "whether to scale noise added")
|
||||
flags.DEFINE_bool("large_model", False, "whether to use a small or large model")
|
||||
flags.DEFINE_bool("larger_model", False, "Whether to use a large model")
|
||||
flags.DEFINE_bool("wider_model", False, "Whether to use a large model")
|
||||
flags.DEFINE_bool("single", False, "single ")
|
||||
flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
|
||||
flags.DEFINE_string("dataset", "cifar10", "cifar10 or imagenet or imagenetfull")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class InceptionReplayBuffer(object):
|
||||
def __init__(self, size):
|
||||
"""Create Replay buffer.
|
||||
|
@ -72,14 +75,16 @@ class InceptionReplayBuffer(object):
|
|||
self._label_storage.extend(list(labels))
|
||||
else:
|
||||
if batch_size + self._next_idx < self._maxsize:
|
||||
self._storage[self._next_idx:self._next_idx+batch_size] = list(ims)
|
||||
self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels)
|
||||
self._storage[self._next_idx : self._next_idx + batch_size] = list(ims)
|
||||
self._label_storage[self._next_idx : self._next_idx + batch_size] = (
|
||||
list(labels)
|
||||
)
|
||||
else:
|
||||
split_idx = self._maxsize - self._next_idx
|
||||
self._storage[self._next_idx:] = list(ims)[:split_idx]
|
||||
self._storage[:batch_size-split_idx] = list(ims)[split_idx:]
|
||||
self._label_storage[self._next_idx:] = list(labels)[:split_idx]
|
||||
self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:]
|
||||
self._storage[self._next_idx :] = list(ims)[:split_idx]
|
||||
self._storage[: batch_size - split_idx] = list(ims)[split_idx:]
|
||||
self._label_storage[self._next_idx :] = list(labels)[:split_idx]
|
||||
self._label_storage[: batch_size - split_idx] = list(labels)[split_idx:]
|
||||
|
||||
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
|
||||
|
||||
|
@ -123,12 +128,13 @@ class InceptionReplayBuffer(object):
|
|||
def rescale_im(im):
|
||||
return np.clip(im * 256, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def compute_inception(sess, target_vars):
|
||||
X_START = target_vars['X_START']
|
||||
Y_GT = target_vars['Y_GT']
|
||||
X_finals = target_vars['X_finals']
|
||||
NOISE_SCALE = target_vars['NOISE_SCALE']
|
||||
energy_noise = target_vars['energy_noise']
|
||||
X_START = target_vars["X_START"]
|
||||
Y_GT = target_vars["Y_GT"]
|
||||
X_finals = target_vars["X_finals"]
|
||||
NOISE_SCALE = target_vars["NOISE_SCALE"]
|
||||
energy_noise = target_vars["energy_noise"]
|
||||
|
||||
size = FLAGS.im_number
|
||||
num_steps = size // 1000
|
||||
|
@ -136,16 +142,21 @@ def compute_inception(sess, target_vars):
|
|||
images = []
|
||||
test_ims = []
|
||||
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
test_dataset = Cifar10(full=True, noise=False)
|
||||
elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
|
||||
test_dataset = Imagenet(train=False)
|
||||
|
||||
if FLAGS.dataset != "imagenetfull":
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_workers=4,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)
|
||||
test_dataloader = TFImagenetLoader("test", FLAGS.batch_size, 0, 1)
|
||||
|
||||
for data_corrupt, data, label_gt in tqdm(test_dataloader):
|
||||
data = data.numpy()
|
||||
|
@ -155,7 +166,6 @@ def compute_inception(sess, target_vars):
|
|||
test_ims = test_ims[:60000]
|
||||
break
|
||||
|
||||
|
||||
# n = min(len(images), len(test_ims))
|
||||
print(len(test_ims))
|
||||
# fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
|
||||
|
@ -187,12 +197,14 @@ def compute_inception(sess, target_vars):
|
|||
x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
|
||||
label = np.random.randint(0, classes, (FLAGS.batch_size))
|
||||
label = identity[label]
|
||||
x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
|
||||
x_new = sess.run(
|
||||
[x_final], {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale}
|
||||
)[0]
|
||||
data_buffer.add(x_new, label)
|
||||
else:
|
||||
(x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
|
||||
keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
|
||||
label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
|
||||
keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99
|
||||
label_keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9
|
||||
label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
|
||||
label_corrupt = identity[label_corrupt]
|
||||
x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
|
||||
|
@ -203,7 +215,10 @@ def compute_inception(sess, target_vars):
|
|||
# else:
|
||||
# noise_scale = [0.7]
|
||||
|
||||
x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
|
||||
x_new, e_noise = sess.run(
|
||||
[x_final, energy_noise],
|
||||
{X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale},
|
||||
)
|
||||
data_buffer.set_elms(idx, x_new, label)
|
||||
|
||||
if FLAGS.im_number != 50000:
|
||||
|
@ -216,14 +231,22 @@ def compute_inception(sess, target_vars):
|
|||
|
||||
images.extend(list(ims))
|
||||
|
||||
saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
|
||||
saveim = osp.join("sandbox_cachedir", FLAGS.exp, "test{}.png".format(FLAGS.idx))
|
||||
|
||||
ims = ims[:100]
|
||||
|
||||
if FLAGS.dataset != "imagenetfull":
|
||||
im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
|
||||
im_panel = (
|
||||
ims.reshape((10, 10, 32, 32, 3))
|
||||
.transpose((0, 2, 1, 3, 4))
|
||||
.reshape((320, 320, 3))
|
||||
)
|
||||
else:
|
||||
im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3))
|
||||
im_panel = (
|
||||
ims.reshape((10, 10, 128, 128, 3))
|
||||
.transpose((0, 2, 1, 3, 4))
|
||||
.reshape((1280, 1280, 3))
|
||||
)
|
||||
imsave(saveim, im_panel)
|
||||
|
||||
print("Saved image!!!!")
|
||||
|
@ -237,8 +260,6 @@ def compute_inception(sess, target_vars):
|
|||
print("FID of score {}".format(fid))
|
||||
|
||||
|
||||
|
||||
|
||||
def main(model_list):
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
|
@ -259,45 +280,55 @@ def main(model_list):
|
|||
weights = []
|
||||
|
||||
for i, model_num in enumerate(model_list):
|
||||
weight = model.construct_weights('context_{}'.format(i))
|
||||
weight = model.construct_weights("context_{}".format(i))
|
||||
initialize()
|
||||
save_file = osp.join(logdir, 'model_{}'.format(model_num))
|
||||
save_file = osp.join(logdir, "model_{}".format(model_num))
|
||||
|
||||
v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i))
|
||||
v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list}
|
||||
v_list = tf.get_collection(
|
||||
tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(i)
|
||||
)
|
||||
v_map = {
|
||||
(v.name.replace("context_{}".format(i), "context_0")[:-2]): v
|
||||
for v in v_list
|
||||
}
|
||||
saver = tf.train.Saver(v_map)
|
||||
try:
|
||||
saver.restore(sess, save_file)
|
||||
except:
|
||||
except BaseException:
|
||||
optimistic_remap_restore(sess, save_file, i)
|
||||
weights.append(weight)
|
||||
|
||||
|
||||
if FLAGS.dataset == "imagenetfull":
|
||||
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32)
|
||||
X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
|
||||
else:
|
||||
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
|
||||
X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
|
||||
|
||||
if FLAGS.dataset == "cifar10":
|
||||
Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
|
||||
Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
|
||||
else:
|
||||
Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
|
||||
Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
|
||||
|
||||
NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
|
||||
|
||||
X_finals = []
|
||||
|
||||
|
||||
# Seperate loops
|
||||
for weight in weights:
|
||||
X = X_START
|
||||
|
||||
steps = tf.constant(0)
|
||||
c = lambda i, x: tf.less(i, FLAGS.num_steps)
|
||||
|
||||
def c(i, x):
|
||||
return tf.less(i, FLAGS.num_steps)
|
||||
|
||||
def langevin_step(counter, X):
|
||||
scale_rate = 1
|
||||
|
||||
X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)
|
||||
X = X + tf.random_normal(
|
||||
tf.shape(X),
|
||||
mean=0.0,
|
||||
stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE,
|
||||
)
|
||||
|
||||
energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
|
||||
x_grad = tf.gradients(energy_noise, [X])[0]
|
||||
|
@ -305,7 +336,7 @@ def main(model_list):
|
|||
if FLAGS.proj_norm != 0.0:
|
||||
x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
|
||||
|
||||
X = X - FLAGS.step_lr * x_grad * scale_rate
|
||||
X = X - FLAGS.step_lr * x_grad * scale_rate
|
||||
X = tf.clip_by_value(X, 0, 1)
|
||||
|
||||
counter = counter + 1
|
||||
|
@ -318,16 +349,16 @@ def main(model_list):
|
|||
X_finals.append(X_final)
|
||||
|
||||
target_vars = {}
|
||||
target_vars['X_START'] = X_START
|
||||
target_vars['Y_GT'] = Y_GT
|
||||
target_vars['X_finals'] = X_finals
|
||||
target_vars['NOISE_SCALE'] = NOISE_SCALE
|
||||
target_vars['energy_noise'] = energy_noise
|
||||
target_vars["X_START"] = X_START
|
||||
target_vars["Y_GT"] = Y_GT
|
||||
target_vars["X_finals"] = X_finals
|
||||
target_vars["NOISE_SCALE"] = NOISE_SCALE
|
||||
target_vars["energy_noise"] = energy_noise
|
||||
|
||||
compute_inception(sess, target_vars)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# model_list = [117000, 116700]
|
||||
model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)]
|
||||
model_list = [FLAGS.resume_iter - 300 * i for i in range(FLAGS.ensemble)]
|
||||
main(model_list)
|
||||
|
|
711
EBMs/train.py
711
EBMs/train.py
File diff suppressed because it is too large
Load diff
737
EBMs/utils.py
737
EBMs/utils.py
File diff suppressed because it is too large
Load diff
22
Makefile
Normal file
22
Makefile
Normal file
|
@ -0,0 +1,22 @@
|
|||
.PHONY: install lint clean
|
||||
|
||||
install:
|
||||
@echo "creating virtual environment..."
|
||||
python3 -m venv venv
|
||||
@echo "run: source venv/bin/activate"
|
||||
venv/bin/pip3 install -r scripts/requirements.txt
|
||||
|
||||
lint:
|
||||
venv/bin/python3 scripts/auto_fix.py
|
||||
|
||||
clean:
|
||||
@echo "🧹 cleaning build artifacts and cache..."
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
find . -type f -name "*.pyo" -delete 2>/dev/null || true
|
||||
find . -type f -name "*.pyd" -delete 2>/dev/null || true
|
||||
find . -type f -name ".coverage" -delete 2>/dev/null || true
|
||||
find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@echo "✨ cleanup complete!"
|
22
README.md
22
README.md
|
@ -1,19 +1,27 @@
|
|||
## resources and experiments on autonomous agents
|
||||
|
||||
## the AI toolkit
|
||||
|
||||
<br>
|
||||
|
||||
* **[⬛ ai && ml tl; dr](deep_learning)**
|
||||
#### research
|
||||
|
||||
* **[⬛ machine learning history](deep_learning)**
|
||||
* **[⬛ large language models](llms)**
|
||||
* **[⬛ agents on blockchains](crypto_agents)**
|
||||
* **[⬛ on quantum computing](EBMs)**
|
||||
- my adaptation of openai's implicit generation and generalization in energy based models
|
||||
|
||||
<br>
|
||||
|
||||
#### experiments
|
||||
|
||||
* **[⬛ agents on blockchains](crypto_agents)**
|
||||
* **[⬛ on quantum computing](EBMs)** (my adaptation of openAI's implicit generation and generalization in energy based
|
||||
models)
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
### cool resources
|
||||
### cool discussions
|
||||
|
||||
<br>
|
||||
|
||||
* **[vub's response to AI 2027 and his take on defense (2025)](https://vitalik.eth.limo/general/2025/07/10/2027.html)**
|
||||
* **[mr. vp jd vance at the ai action summit in paris (2025)](https://www.youtube.com/watch?v=MnKsxnP2IVk)**
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
## crypto agents
|
||||
## agents on blockchains
|
||||
|
||||
<br>
|
||||
|
||||
* **[basic strategy workflow](strategy_workflow)**
|
||||
* **[basic strategy workflow (2023)](strategy_workflow)**
|
||||
* **[trading on gmx (2023)](trading_on_gmx.md)**
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -12,7 +13,6 @@
|
|||
|
||||
<br>
|
||||
|
||||
|
||||
##### projects
|
||||
|
||||
* **[ritual.net](https://ritual.net/)**
|
||||
|
@ -27,26 +27,37 @@
|
|||
|
||||
##### readings
|
||||
|
||||
* **[the internet's notary public: why verifiability matters, by axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)**
|
||||
* **[cryptos role in the ai revolution, by pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)**
|
||||
* **[the promise and challenges of crypto + ai applications, by vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)**
|
||||
* **[on training defi agents with markov chains, by bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)**
|
||||
* **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)**
|
||||
* **[the internet's notary public: why verifiability matters, by
|
||||
axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)**
|
||||
* **[cryptos role in the ai revolution, by
|
||||
pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)**
|
||||
* **[the promise and challenges of crypto + ai applications, by
|
||||
vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)**
|
||||
* **[on training defi agents with markov chains, by
|
||||
bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)**
|
||||
|
||||
<br>
|
||||
|
||||
##### metas
|
||||
|
||||
* **eth denver 2025 ai + agentic metas**:
|
||||
* **[a new dawn for decentralized and the convergence of ai x blockchain, by n. ameline](https://www.youtube.com/watch?v=HQuGtN9zidQ)**
|
||||
* **[upgrading agent infra with onchain perpetualaAgents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)**
|
||||
* **[ai meets blockchain: payment, multi-Agent & distributed inference, by o. jaros](https://www.youtube.com/watch?v=aPTosw4hrY0)**
|
||||
* **[a new dawn for decentralized and the convergence of ai x blockchain, by n.
|
||||
* **[upgrading agent infra with onchain perpetualaAgents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)**
|
||||
* **[ai meets blockchain: payment, multi-Agent & distributed inference, by o.
|
||||
* **[from ai agents to agentic economies, by r. bodkin](https://www.youtube.com/watch?v=Q7eaYJ9aPpI)**
|
||||
* **[building ai agents and agent economies, by d. minarsch](https://www.youtube.com/watch?v=tDVK2Q5RY0c)**
|
||||
* **[building ai agents on top of hedera, by j. hall](https://www.youtube.com/watch?v=h8D6vi2m8LQ)**
|
||||
* **[identity, privacy, and security in the new age of ai agents, by m. csernai](https://www.youtube.com/watch?v=vMcaot04RQo)**
|
||||
* **[identity, privacy, and security in the new age of ai agents, by m.
|
||||
* **[verifiable ai agents and world superintelligence, by k. wong](https://www.youtube.com/watch?v=ngkp7HTj_4A)**
|
||||
* **[how web3 can compete in ai, by g. narula](https://www.youtube.com/watch?v=oLLM1I-3fDU)**
|
||||
* **[shade agents, by m. lockyer](https://www.youtube.com/watch?v=PEfJnCtrbMU)**
|
||||
* **[how ai will enable the next generatin of intent, by i. yang](https://www.youtube.com/watch?v=fbc3DpI6jiA)**
|
||||
* **[2025 is the year of agents, by i. polosukhin](https://www.youtube.com/watch?v=jPyzVNcQMKw)**
|
||||
* **[our awesome decentralized AI](https://github.com/shadowy-forest/awesome-decentralized-ai)**
|
||||
|
||||
<br>
|
||||
|
||||
##### books
|
||||
|
||||
* **[advances in financial machine learning](../books/advances_in_financial_machine_learning.pdf)**
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
<br>
|
||||
|
||||
<p align="center">
|
||||
<img width="854" src="https://user-images.githubusercontent.com/1130416/227752772-5d739fd8-1b5c-4841-a52a-7cda308fc4df.png">
|
||||
<img width="854"
|
||||
src="https://user-images.githubusercontent.com/1130416/227752772-5d739fd8-1b5c-4841-a52a-7cda308fc4df.png">
|
||||
</p>
|
||||
|
||||
<br>
|
||||
|
|
|
@ -5,59 +5,81 @@
|
|||
|
||||
### A
|
||||
|
||||
- Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage of their price discrepancies.
|
||||
- Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of clients.
|
||||
- Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage
|
||||
of their price discrepancies.
|
||||
- Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of
|
||||
clients.
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
### B
|
||||
|
||||
- Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target transaction.
|
||||
- Blocks: a block contains transaction data and the hash of the previous block ensuring immutability in the blockchain network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the updates to the blockchain state.
|
||||
- Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target
|
||||
transaction.
|
||||
- Blocks: a block contains transaction data and the hash of the previous block ensuring immutability in the blockchain
|
||||
network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the
|
||||
updates to the blockchain state.
|
||||
- Block time: the time interval between blocks being added to the blockchain.
|
||||
- Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to include the transaction to the network. This request is public (anyone can listen to it).
|
||||
- Builders: actors that take bundles (of pendent transactions from the mempool) and create a final block to send to (multiple) relays (setting themselves afeeRecipient to receive the block’s MEV).
|
||||
- Bundles: one or more transactions that are grouped together and executed in the order they are provided. In addition to the searcher's transaction(s), a bundle can also contain other users' pending transactions from the mempool. Bundles can target specific blocks for inclusion as well.
|
||||
- Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to include the transaction to
|
||||
the network. This request is public (anyone can listen to it).
|
||||
- Builders: actors that take bundles (of pendent transactions from the mempool) and create a final block to send to
|
||||
(multiple) relays (setting themselves afeeRecipient to receive the block’s MEV).
|
||||
- Bundles: one or more transactions that are grouped together and executed in the order they are provided.
|
||||
|
||||
<br>
|
||||
|
||||
### C
|
||||
|
||||
- Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size that they are willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until the desired size is reached.
|
||||
- Contract address: the address hosting some source code deployed on the Ethereum blockchain, which is executed by a triggering transaction.
|
||||
- Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another trader's method.
|
||||
- Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size that they are
|
||||
willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until
|
||||
the desired size is reached.
|
||||
- Contract address: the address hosting some source code deployed on the Ethereum blockchain, which is executed by a
|
||||
triggering transaction.
|
||||
- Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another
|
||||
trader's method.
|
||||
|
||||
<br>
|
||||
|
||||
### D
|
||||
|
||||
- Derivatives: financial contracts that derive their values from underlying assets.
|
||||
- Dollar-cost-averaging (DCA) strategy: a one-stop automated trading, based on time intervals, and reducing the influence of market volatility. Parameters for DCA can be: currency, fixed/maximum investment, and amount, investment frequency.
|
||||
- Dollar-cost-averaging (DCA) strategy: a one-stop automated trading, based on time intervals, and reducing the
|
||||
influence of market volatility. Parameters for DCA can be: currency, fixed/maximum investment, and amount, investment
|
||||
frequency.
|
||||
|
||||
<br>
|
||||
|
||||
### E
|
||||
|
||||
- Epoch: in the context of Ethereum's block production, in each slot (every 12 seconds), a validator is randomly chosen to propose the block in that slot. An epoch contains 32 slots.
|
||||
- Externally owned account (EOA): an account that is a combination of public address and private key, and that can be used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal address derived from the last 20 bytes of the public key of the account (with 0x appended in front).
|
||||
- Epoch: in the context of Ethereum's block production, in each slot (every 12 seconds), a validator is randomly chosen
|
||||
to propose the block in that slot. An epoch contains 32 slots.
|
||||
- Externally owned account (EOA): an account that is a combination of public address and private key, and that can be
|
||||
used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal address
|
||||
derived from the last 20 bytes of the public key of the account (with 0x appended in front).
|
||||
|
||||
<br>
|
||||
|
||||
### F
|
||||
|
||||
- Frontrunning: the process by which an adversary observes transactions on the network layer and acts on this information to obtain profit.
|
||||
- Frontrunning: the process by which an adversary observes transactions on the network layer and acts on this
|
||||
information to obtain profit.
|
||||
- Fully diluted valuations (FDV): the total number of tokens multiplied by the current price of a single token.
|
||||
- Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their price changes.
|
||||
- Future grid trading bots: bots that automate futures trading activities based on grid trading strategies (a set of orders is placed both above and below a specific reference market price for the asset).
|
||||
- Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their
|
||||
price changes.
|
||||
- Future grid trading bots: bots that automate futures trading activities based on grid trading strategies (a set of
|
||||
orders is placed both above and below a specific reference market price for the asset).
|
||||
|
||||
<br>
|
||||
|
||||
### G
|
||||
|
||||
- Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of execution) to have their transaction processed.
|
||||
- Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency. A gwei or gigawei is defined as 1,000,000,000 wei, the smallest base unit of Ether. Conversely, 1 ETH represents 1 billion gwei.
|
||||
- Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching goal of buying low and selling high.
|
||||
- Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of execution) to have
|
||||
their transaction processed.
|
||||
- Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency.
|
||||
- Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of
|
||||
orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching
|
||||
goal of buying low and selling high.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -75,10 +97,12 @@
|
|||
|
||||
### L
|
||||
|
||||
- Limit orders: when one longs or shorts a contract, several execution options can be placed (usually with a fee difference). Limit orders that are set at a specific price to be traded, and there is no guarantee that the trade will be executed (see market orders and stop-loss orders).
|
||||
- Liquidity pools: a collection of crypto assets that can be used for decentralized trading. They are essential for automated market makers (AMM), borrow-lend protocols, yield farming, synthetic assets, on-chain insurance, blockchain gaming, etc.
|
||||
- Limit orders: when one longs or shorts a contract, several execution options can be placed (usually with a fee
|
||||
difference). Limit orders that are set at a specific price to be traded, and there is no guarantee that the trade will
|
||||
be executed (see market orders and stop-loss orders).
|
||||
- Liquidity pools: a collection of crypto assets that can be used for decentralized trading.
|
||||
- Liquidation threshold: the percentage at which a collateral value is counted towards the borrowing capacity.
|
||||
- Liquidation: when the value of a borrowed asset exceeds the collateral. Anyone can liquidate the collateral and collect the liquidation fee for themselves.
|
||||
- Liquidation: when the value of a borrowed asset exceeds the collateral.
|
||||
- Long: traders maintain long positions, which means that they expect the price of a coin to rise in the future.
|
||||
|
||||
<br>
|
||||
|
@ -86,23 +110,32 @@
|
|||
### M
|
||||
|
||||
- Fully diluted market capitalization: the total token supply, multiplied by the price of a single token.
|
||||
- Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the price of a single token.
|
||||
- Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the
|
||||
price of a single token.
|
||||
- Margin trading: buying or sell assets with leverage.
|
||||
- Marginal seller: a type of seller who is willing first to leave the market if the prices are lower.
|
||||
- Market orders: Market orders are executed immediately at the asset's market price (see limit orders).
|
||||
- Mean reversion strategy: a trading range (or mean reversion) strategy is based on the concept that an asset's high and low prices are a temporary effect that reverts to their mean value (average value).
|
||||
- Mempool: a cryptocurrency node’s mechanism for storing information on unconfirmed transactions.
|
||||
- Merkle tree: a type of binary tree, composed of: 1) a set of notes with a large number of leaf nodes at the bottom, containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a single root node, also formed from the hash of its two children, representing the top of the tree.
|
||||
- Minting: the process of validating information, creating a new block, and recording that information into the blockchain.
|
||||
- Mean reversion strategy: a trading range (or mean reversion) strategy is based on the concept that an asset's high and
|
||||
low prices are a temporary effect that reverts to their mean value (average value).
|
||||
- Mempool: a cryptocurrency node’s mechanism for storing information on unconfirmed transactions.
|
||||
- Merkle tree: a type of binary tree, composed of: 1) a set of notes with a large number of leaf nodes at the bottom,
|
||||
containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a
|
||||
single root node, also formed from the hash of its two children, representing the top of the tree.
|
||||
- Minting: the process of validating information, creating a new block, and recording that information into the
|
||||
blockchain.
|
||||
|
||||
<br>
|
||||
|
||||
### P
|
||||
|
||||
- Perpetual contract: a contract without an expiration date, where interest rates can be calculated by methods such as Time-Weighted-Average-Price (TWAP).
|
||||
- Priority gas auctions: bots compete against each other by binding up transaction fees (gas) to extract revenue from arbitrage opportunities, driving up user fees.
|
||||
- Private key: a secret number enabling a blockchain user to prove ownership on an account or contract, via a digital signature.
|
||||
- Publick key: a number generated by a one-way (hash) function from the private key, used to verify a digital signature made with the matching private key.
|
||||
- Perpetual contract: a contract without an expiration date, where interest rates can be calculated by methods such as
|
||||
Time-Weighted-Average-Price (TWAP).
|
||||
- Priority gas auctions: bots compete against each other by binding up transaction fees (gas) to extract revenue from
|
||||
arbitrage opportunities, driving up user fees.
|
||||
- Private key: a secret number enabling a blockchain user to prove ownership on an account or contract, via a digital
|
||||
signature.
|
||||
- Publick key: a number generated by a one-way (hash) function from the private key, used to verify a digital signature
|
||||
made with the matching private key.
|
||||
- Provider: an entity that provides an abstraction for a connection to the blockchain network.
|
||||
- POFPs: private order flow protocols.
|
||||
|
||||
|
@ -110,8 +143,9 @@
|
|||
|
||||
### O
|
||||
|
||||
- Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state of the blockchain.
|
||||
- Open interest: total number of futures contracts held by market participants at the end of the trading day. Used as an indicator to determine market sentiment and the strength behind price trends.
|
||||
- Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state
|
||||
of the blockchain.
|
||||
- Open interest: total number of futures contracts held by market participants at the end of the trading day.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -123,37 +157,48 @@
|
|||
|
||||
### S
|
||||
|
||||
- Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen validator has time to propose a block.
|
||||
- Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties. They are reliant upon code (the functions) and data (the state), and they can trigger specific actions, such as transferring tokens from A to B.
|
||||
- Sandwich attack: when slippage value is not set, this attack can happen by an actor bumping the price of an asset to an unfavorable level, executing the trade, and then returning the asset to the original price.
|
||||
- Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen
|
||||
validator has time to propose a block.
|
||||
- Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties.
|
||||
- Sandwich attack: when slippage value is not set, this attack can happen by an actor bumping the price of an asset to
|
||||
an unfavorable level, executing the trade, and then returning the asset to the original price.
|
||||
- Slippage: delta in pricing between the time of order and when the order is executed.
|
||||
- Short: traders maintain short positions, which means they expect the price of a coin to drop in the future.
|
||||
- Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason. This situation prompts short sellers to scramble to buy the stock to cover their positions and cap their mounting losses.
|
||||
- Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason.
|
||||
- Spot trading: buy or selling assets for immediate delivery.
|
||||
- Statistical trading: is the class of strategies that aim to generate profitable situations, stemming from pricing inefficiencies among financial markets. Statistical arbitrage is a strategy to obtain profit by applying past statistics.
|
||||
- Stop-loss orders: this type of order execution places a market/limit order to close a position to restrict an investor's loss on a crypto asset.
|
||||
- Statistical trading: is the class of strategies that aim to generate profitable situations, stemming from pricing
|
||||
inefficiencies among financial markets. Statistical arbitrage is a strategy to obtain profit by applying past
|
||||
statistics.
|
||||
- Stop-loss orders: this type of order execution places a market/limit order to close a position to restrict an
|
||||
investor's loss on a crypto asset.
|
||||
|
||||
<br>
|
||||
|
||||
### T
|
||||
|
||||
- otal value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or derivatives protocols.
|
||||
- otal value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or
|
||||
derivatives protocols.
|
||||
- Тrading volume: the total amount of traded cryptocurrency (equivalent to US dollars) during a given timeframe.
|
||||
- Transaction: on EVM-based blockchains, there the two types of transactions are normal transactions and contract interactions.
|
||||
- Transaction: on EVM-based blockchains, there the two types of transactions are normal transactions and contract
|
||||
interactions.
|
||||
- Transaction hash: a unique 66-character identifier generated with each new transaction.
|
||||
- Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block, allowing attacks that benefit from certain ordering.
|
||||
- Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using evenly divided time slots between a start and end time.
|
||||
- Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block,
|
||||
allowing attacks that benefit from certain ordering.
|
||||
- Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined
|
||||
smaller chunks of the order to the market, using evenly divided time slots between a start and end time.
|
||||
|
||||
<br>
|
||||
|
||||
### V
|
||||
|
||||
- Validation: a mathematical proof that the state change in the blockchain is consistent. To be included into a block in the blockchain, a list of transactions needs to be validated.
|
||||
- Validation: a mathematical proof that the state change in the blockchain is consistent.
|
||||
- VTRPs: validator transaction Reordering protocols.
|
||||
- Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using historical volume profiles.
|
||||
- Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller
|
||||
chunks of the order to the market, using historical volume profiles.
|
||||
|
||||
<br>
|
||||
|
||||
### W
|
||||
|
||||
- Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become powerful enough to manipulate the valuation.
|
||||
- Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become
|
||||
powerful enough to manipulate the valuation.
|
||||
|
|
|
@ -2,5 +2,6 @@
|
|||
|
||||
<br>
|
||||
|
||||
* perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients (using the simulator and a set of historical data)
|
||||
* perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients
|
||||
(using the simulator and a set of historical data)
|
||||
* overfitting to historical data is a big risk (be careful with validation and test sets).
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
|
||||
<br>
|
||||
|
||||
* before the strategy goes live, simulation is done on new market data, in real-time (paper trading), which prevents overfitting
|
||||
* before the strategy goes live, simulation is done on new market data, in real-time (paper trading), which prevents
|
||||
overfitting
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
|
||||
<br>
|
||||
|
||||
* come with a rule-based policy that determines what actions to take based on the current state of the market and the outpus of supervised models.
|
||||
* come with a rule-based policy that determines what actions to take based on the current state of the market and the
|
||||
outputs of supervised models.
|
||||
|
|
|
@ -2,8 +2,12 @@
|
|||
|
||||
<br>
|
||||
|
||||
* **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period, minus trading fees
|
||||
* **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period,
|
||||
minus trading fees
|
||||
* **alpha nad beta**
|
||||
* **shape ratio:** the excess return per unit of risk you are taking (return on capital over the standard deviation adjusted for risk; the higher the better).
|
||||
* **maximum drawdown:** maximum difference between a local maximum and a subsequent local minimum as an another measure of risk.
|
||||
* **value at risk (var):** how much capital you may lose over a given time frame with some probability, assumong normal market conditions.
|
||||
* **shape ratio:** the excess return per unit of risk you are taking (return on capital over the standard deviation
|
||||
adjusted for risk; the higher the better).
|
||||
* **maximum drawdown:** maximum difference between a local maximum and a subsequent local minimum as an another measure
|
||||
of risk.
|
||||
* **value at risk (var):** how much capital you may lose over a given time frame with some probability, assumong normal
|
||||
market conditions.
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
|
||||
<br>
|
||||
|
||||
* train one or more supervised learning models to predict quantities of interest that are necessary for the strategy work, for example, price prediction, quantity prediction, etc.
|
||||
* train one or more supervised learning models to predict quantities of interest that are necessary for the strategy
|
||||
work, for example, price prediction, quantity prediction, etc.
|
||||
|
|
|
@ -13,8 +13,10 @@
|
|||
<br>
|
||||
|
||||
|
||||
<img width="400" src="https://user-images.githubusercontent.com/1130416/227733463-d0dff53f-9a5f-45f3-80a4-9d9ab0d9201e.png">
|
||||
<img width="400" src="https://user-images.githubusercontent.com/1130416/227733575-90550afd-99f2-45cc-b6aa-fd4457910cc5.png">
|
||||
<img width="400"
|
||||
src="https://user-images.githubusercontent.com/1130416/227733463-d0dff53f-9a5f-45f3-80a4-9d9ab0d9201e.png">
|
||||
<img width="400"
|
||||
src="https://user-images.githubusercontent.com/1130416/227733575-90550afd-99f2-45cc-b6aa-fd4457910cc5.png">
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -25,10 +27,12 @@
|
|||
<br>
|
||||
|
||||
* the order book is made of two sides, asks (sell, offers) and bids (buy).
|
||||
* the best ask (the lowest price someone is willing to sell ) > the best bid (the highest price someone is willing to buy).
|
||||
* the best ask (the lowest price someone is willing to sell ) > the best bid (the highest price someone is willing to
|
||||
buy).
|
||||
* the difference between the best ask and the best bid is called spread.
|
||||
* **market order**: best price possible, right now. it takes liquidity from the market and usually has higher fees.
|
||||
* **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the match.
|
||||
* **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the
|
||||
match.
|
||||
* **stop orders**: allow you to set a maximum price for your market orders.
|
||||
|
||||
<br>
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
## ai agents
|
||||
## some machine learning history
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -13,8 +13,6 @@
|
|||
|
||||
<br>
|
||||
|
||||
* **[cursor ai editor](https://www.cursor.com/)**
|
||||
* **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)**
|
||||
* **[google's jax (composable transformations of numpy programs)](https://github.com/google/jax)**
|
||||
* **[machine learning engineering open book](https://github.com/stas00/ml-engineering)**
|
||||
* **[advances in financial machine learning](books/advances_in_financial_machine_learning.pdf)**
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
## deep learning
|
||||
## deep learning
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -10,9 +10,11 @@
|
|||
|
||||
* **[2013: atari with deep reinforcement learning](https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial)**
|
||||
* **[2014: seq2seq](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt)**
|
||||
* **[2014: adam optmizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)**
|
||||
* **[2014: adam
|
||||
optmizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)**
|
||||
* **[2015: gans](https://www.tensorflow.org/tutorials/generative/dcgan)**
|
||||
* **[2015: resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)**
|
||||
* **[2015:
|
||||
resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)**
|
||||
* **[2017: transformers](https://github.com/huggingface/transformers)**
|
||||
* **[2018: bert](https://arxiv.org/abs/1810.04805)**
|
||||
|
||||
|
@ -24,30 +26,38 @@
|
|||
|
||||
<br>
|
||||
|
||||
* a map consists of a set of states, a set of actions, a transition function that describes the probability of moving rom one state to another after taking an action, and a reward function that assigns a numerical reward to each state-action pair
|
||||
* a map consists of a set of states, a set of actions, a transition function that describes the probability of moving
|
||||
rom one state to another after taking an action, and a reward function that assigns a numerical reward to each
|
||||
state-action pair
|
||||
|
||||
* the goal of a map is to maximize its expected cumulative reward over a sequence of actions, called a policy.
|
||||
|
||||
* a policy is a function that maps each state to a probability distribution over actions. The optimal policy is the one that maximizes the expected cumulative rewards.
|
||||
* a policy is a function that maps each state to a probability distribution over actions.
|
||||
|
||||
* the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as the optimal control of incompletely-known Markov decision processes.
|
||||
* the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as
|
||||
the optimal control of incompletely-known Markov decision processes.
|
||||
|
||||
* as opposed to supervised learning, an agent must be able to learn from its own experience. and as oppose to unsupervised learning because, reinforcement learning is trying to maximize a reward signal instead of trying to find hidden structure.
|
||||
* as opposed to supervised learning, an agent must be able to learn from its own experience.
|
||||
|
||||
* the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain a reliable estimate of its expected reward.
|
||||
* the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in
|
||||
order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain
|
||||
a reliable estimate of its expected reward.
|
||||
|
||||
* beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a policy, a reward signal, a value function, and, optionally, a model of the environment.
|
||||
* beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a
|
||||
policy, a reward signal, a value function, and, optionally, a model of the environment.
|
||||
|
||||
* traditional reinforcement learning problems can be formulated as a markov decision process (MDP):
|
||||
* traditional reinforcement learning problems can be formulated as a markov decision process (MDP):
|
||||
* we have an agent acting in an environment
|
||||
* each step *t* the agent receives as the input the current state S_t, takes an action A_t, and receives a reward R_{t+1} and the next state S_{t+1}
|
||||
* each step *t* the agent receives as the input the current state S_t, takes an action A_t, and receives a reward
|
||||
R_{t+1} and the next state S_{t+1}
|
||||
* the agent choose the action based on some policy pi: A_t = pi(S_t)
|
||||
* it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon
|
||||
* it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon
|
||||
|
||||
|
||||
<br>
|
||||
|
||||
<img width="500" src="https://user-images.githubusercontent.com/1130416/227799494-d62aab7f-d6cf-419f-be03-1d2dbdee1853.png">
|
||||
<img width="500"
|
||||
src="https://user-images.githubusercontent.com/1130416/227799494-d62aab7f-d6cf-419f-be03-1d2dbdee1853.png">
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -55,7 +65,7 @@
|
|||
|
||||
<br>
|
||||
|
||||
* agent is the trading agent (e.g. the human trader who opens the gui of an exchange and makes trading decision based on the current state of the exchange and their account)
|
||||
* agent is the trading agent (e.g.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -65,7 +75,8 @@
|
|||
|
||||
* the exchange and other agents are the environment, and they are not something we can control
|
||||
* by putting other agents together into some big complex environment, we lose the ability to explicitly model them
|
||||
* if we try to reverse-engineer the algorithms and strategies that other traders are running, put us into a multi-agent reinforcement learning (MARL) problem setting
|
||||
* if we try to reverse-engineer the algorithms and strategies that other traders are running, put us into a multi-agent
|
||||
reinforcement learning (MARL) problem setting
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -73,11 +84,12 @@
|
|||
|
||||
<br>
|
||||
|
||||
* in the case of trading on an exchange, we don't observe the complete state of the environment (e.g. other agents), so we are dealing with a partially observable markov decision process (pomdp).
|
||||
* in the case of trading on an exchange, we don't observe the complete state of the environment (e.g.
|
||||
* what the agents observe is not the actual state S_t of the environment, but some derivation of that.
|
||||
* we can call the observation X_t, which is calculated using some function of the full state X_t ~ O(S_t)
|
||||
* the observation at each timestep t is simply the history of all exchange events received up to time t.
|
||||
* this event history can be used to build up the current exchange state, however, in order for our agent to make decisions, extra info such as account balance and open limit orders need to be included.
|
||||
* this event history can be used to build up the current exchange state, however, in order for our agent to make
|
||||
decisions, extra info such as account balance and open limit orders need to be included.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -85,8 +97,9 @@
|
|||
|
||||
<br>
|
||||
|
||||
* hft techniques: decisions are based almost entirely on market microstructure signals. decisions are made on nanoseconds timescales and trading strategies use dedicated connections to exchanges and extremly fast but simple algorithms running fpga hardware.
|
||||
* neural networks are slow, they can't make predictions on nanoseconds time scales, so they can't compete with the speed of hft algorithms.
|
||||
* hft techniques: decisions are based almost entirely on market microstructure signals.
|
||||
* neural networks are slow, they can't make predictions on nanoseconds time scales, so they can't compete with the speed
|
||||
of hft algorithms.
|
||||
* guess: the optimal time scale is between a few milliseconds and a few minutes.
|
||||
* can deep rl algorithms pick up hidden patterns?
|
||||
|
||||
|
@ -96,9 +109,11 @@
|
|||
|
||||
<br>
|
||||
|
||||
* the simplest approach has 3 actions: buy, hold, and sell. this works but limits us to placing market orders and to invest a deterministic amount of money at each step.
|
||||
* in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model, putting us into a continuous action space.
|
||||
* in the next level, we would introduce limit orders, and the agent needs to decide the level (price) and wuantity of the order, and be able to cancel orders that have not been yet matched.
|
||||
* the simplest approach has 3 actions: buy, hold, and sell.
|
||||
* in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model,
|
||||
putting us into a continuous action space.
|
||||
* in the next level, we would introduce limit orders, and the agent needs to decide the level (price) and wuantity of
|
||||
the order, and be able to cancel orders that have not been yet matched.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -106,18 +121,21 @@
|
|||
|
||||
<br>
|
||||
|
||||
* there are several possible reward functions, an obvious would realized PnL (profit and loss). the agent receives a reward whenever it closes a position.
|
||||
* there are several possible reward functions, an obvious would realized PnL (profit and loss).
|
||||
* the net profit is either negative or positive, and this is the reward signal.
|
||||
* as the agent maximize the total cumulative reward, it learns to trade profitably. the reward function leads to the optimal policy in the limit.
|
||||
* however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent feedback.
|
||||
* an alternative is unrealized pnl, which the net profit the agent would get if it were to close all of its positions immediately.
|
||||
* because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals. however the direct feedback may bias the agent towards short-term actions.
|
||||
* as the agent maximize the total cumulative reward, it learns to trade profitably.
|
||||
* however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent
|
||||
feedback.
|
||||
* an alternative is unrealized pnl, which the net profit the agent would get if it were to close all of its positions
|
||||
immediately.
|
||||
* because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals.
|
||||
* both naively optimize for profit, but a trader may want to minimize risk (lower volatility)
|
||||
* using the sharpe ration is one simple way to take risk into account. other way is maximum drawdown.
|
||||
|
||||
<br>
|
||||
|
||||
<img width="505" src="https://user-images.githubusercontent.com/1130416/227811225-9af06c79-3f86-48e8-899c-ee5a80bc91e1.png">
|
||||
<img width="505"
|
||||
src="https://user-images.githubusercontent.com/1130416/227811225-9af06c79-3f86-48e8-899c-ee5a80bc91e1.png">
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -134,11 +152,14 @@
|
|||
|
||||
<br>
|
||||
|
||||
* we need separate backtesting and parameter optimization steps because it was difficult for our strategies to take into account environmental factors: order book liquidity, fee structures, latencies.
|
||||
* getting around environmental limitations is part of the opimization process. if we simulate the latency in the reinforcement learning environment, and this results in the agent making a mistake, the agent will get a negative rewards, forcing it to learn to work around the latencies.
|
||||
* by learning a model of the environment and performing rollouts using techniques like a monte carlo tree search (mcts), we could take into account potential reactions of the market (other agents)
|
||||
* we need separate backtesting and parameter optimization steps because it was difficult for our strategies to take into
|
||||
account environmental factors: order book liquidity, fee structures, latencies.
|
||||
* getting around environmental limitations is part of the opimization process.
|
||||
* by learning a model of the environment and performing rollouts using techniques like a monte carlo tree search (mcts),
|
||||
we could take into account potential reactions of the market (other agents)
|
||||
* by being smart about the data we collect from the live environment, we can continously improve our model
|
||||
* do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting information that we can use to improve the model of our environment and other agents?
|
||||
* do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting
|
||||
information that we can use to improve the model of our environment and other agents?
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -147,7 +168,9 @@
|
|||
<br>
|
||||
|
||||
* some strategy may work better in a bearish environment but lose money in a bullish environment.
|
||||
* because rl agents are learning powerful policies parameterized by NN, they can alos learn to adapt to market conditions by seeing them in historical data, given that they are trained over long time horizon and have sufficient memory.
|
||||
* because rl agents are learning powerful policies parameterized by NN, they can alos learn to adapt to market
|
||||
conditions by seeing them in historical data, given that they are trained over long time horizon and have sufficient
|
||||
memory.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -156,9 +179,13 @@
|
|||
<br>
|
||||
|
||||
* the trading environment is a multiplayer game with thousands of agents acting simultaneously
|
||||
* understanding how to build models of other agents is only one possible we can, we can choose perfom actions in a live environment with the goal of maximizing the information grain with respect to kind policies the other agents may be following
|
||||
* understanding how to build models of other agents is only one possible we can, we can choose perfom actions in a live
|
||||
environment with the goal of maximizing the information grain with respect to kind policies the other agents may be
|
||||
following
|
||||
* trading agents receive sparse rewards from the market. naively applying reward-hungry rl algorithms will fail.
|
||||
* this opens up the possibility for new algorithms and techniques, that can efficiently deal with sparse rewards.
|
||||
* many of today's standard algorithms, such as dqn or a3c, use a very naive approach exploration - basically adding random noise to the policy. however, in the trading case, most states in the environment are bad, and there are only a few good ones. a naive random approach to exploration will almost never stumble upon good state-actions paris.
|
||||
* the trading environment is inherently nonstationary. market conditions change and other agent join, leave, and constantly change their strategies.
|
||||
* many of today's standard algorithms, such as dqn or a3c, use a very naive approach exploration - basically adding
|
||||
random noise to the policy. however, in the trading case, most states in the environment are bad, and there are only a
|
||||
few good ones. a naive random approach to exploration will almost never stumble upon good state-actions paris.
|
||||
* the trading environment is inherently nonstationary.
|
||||
* can we train an agent that can transit from bear to bull and then back to bear, without needing to be re-trained?
|
||||
|
|
|
@ -6,8 +6,10 @@
|
|||
|
||||
<br>
|
||||
|
||||
* reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward signal
|
||||
* an autonomous agent is a software program or system that can operate independently and make decisions on its own, without direct intervention from a human
|
||||
* reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward
|
||||
signal
|
||||
* an autonomous agent is a software program or system that can operate independently and make decisions on its own,
|
||||
without direct intervention from a human
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -17,10 +19,13 @@
|
|||
|
||||
<br>
|
||||
|
||||
* we formalize the problem of reinforcement using ideas from dynamical system theory, as the optimal control of incompletely-known Markov decision processes.
|
||||
* a learning agent must be able to sense the state of its environment to some extent and must be able to take actions that affect the state.
|
||||
* we formalize the problem of reinforcement using ideas from dynamical system theory, as the optimal control of
|
||||
incompletely-known Markov decision processes.
|
||||
* a learning agent must be able to sense the state of its environment to some extent and must be able to take actions
|
||||
that affect the state.
|
||||
* markov decision processes are intented to include just these three aspects, sensation, action, and goal.
|
||||
* the agent has to exploit what it has already experienced in order to obtain reward, but it has also to explore in order to make better action selections in the future.
|
||||
* the agent has to exploit what it has already experienced in order to obtain reward, but it has also to explore in
|
||||
order to make better action selections in the future.
|
||||
* on a stochastic tasks, each action must be tried many times to gain a reliable estimate of its expected reward.
|
||||
|
||||
<br>
|
||||
|
@ -31,12 +36,17 @@
|
|||
|
||||
<br>
|
||||
|
||||
* beyond the agent and the environment, 4 more elements belong to a reinforcement learning system: a policy, a reward signal, a value funtion, and a model of the environmnet.
|
||||
* a policy defines the learning agent's way of behacing at a given time. It's a mapping from perceiv ed states of the environment to actions to be taken when in those states. in general, policies may be stochastics (specifying probabilities for each action).
|
||||
* a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total reward over the run.
|
||||
* a value function specifies what is good in the long run, the valye of a state in the total amount of reward an agent can expect to accumulate over the future, starting from that state
|
||||
* beyond the agent and the environment, 4 more elements belong to a reinforcement learning system: a policy, a reward
|
||||
signal, a value funtion, and a model of the environmnet.
|
||||
* a policy defines the learning agent's way of behacing at a given time.
|
||||
* a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the
|
||||
reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total
|
||||
reward over the run.
|
||||
* a value function specifies what is good in the long run, the valye of a state in the total amount of reward an agent
|
||||
can expect to accumulate over the future, starting from that state
|
||||
* a model of the environment.
|
||||
* the most important feature distinguishing reinforcement learning from other types of learning is that it uses training information that evaluates the actions taken rather than instructs by giving correct actions.
|
||||
* the most important feature distinguishing reinforcement learning from other types of learning is that it uses training
|
||||
information that evaluates the actions taken rather than instructs by giving correct actions.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -47,7 +57,8 @@
|
|||
<br>
|
||||
|
||||
* the problem involves evaluating feedbacks and choosing different actions in different situations.
|
||||
* mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards, but also subsequent situations.
|
||||
* mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards,
|
||||
but also subsequent situations.
|
||||
* mdps involve delayed reward and the need to trade off immediate and delayed reward.
|
||||
|
||||
<br>
|
||||
|
@ -57,34 +68,42 @@
|
|||
* mdps are meant to be a straightfoward framing of the problem of learning from interaction to achieve a goal.
|
||||
* the learner and the decision makers is called the agent.
|
||||
* the thing it interacts with, comprimising everything outside the agent, is called the environment.
|
||||
* the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice of actions.
|
||||
* the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice
|
||||
of actions.
|
||||
|
||||
<br>
|
||||
|
||||
<img width="466" src="https://user-images.githubusercontent.com/1130416/228971927-3c574911-d0ca-4d2d-b795-8b0776599952.png">
|
||||
<img width="466"
|
||||
src="https://user-images.githubusercontent.com/1130416/228971927-3c574911-d0ca-4d2d-b795-8b0776599952.png">
|
||||
|
||||
<br>
|
||||
|
||||
* the agent and the environment interact at each of a sequence of discrete steps, t = 0, 1, 2, 3...
|
||||
* at each time step t, the agent receives some representation of the environments state St
|
||||
* on that basis, the agent selects an action At
|
||||
* one step later, in part of a consequence of its action, the agent receives a numerical rewards and finds itself in a new state.
|
||||
* one step later, in part of a consequence of its action, the agent receives a numerical rewards and finds itself in a
|
||||
new state.
|
||||
* the mdp and the agent together give rise to a sequence (trajectory)
|
||||
* in a finite mdp, the set of states, actions, and rewards all have a finite number of elements. in this case, the random variables R and S have well defined discrete probability distributions dependent only on the proceding state and action.
|
||||
* in a finite mdp, the set of states, actions, and rewards all have a finite number of elements.
|
||||
* in a markov decision process, the probabilities given by p completely characterize the environment's dynamics.
|
||||
* the state must include information about all aspects of the past agent-environment interaction that make a differnce for the future.
|
||||
* anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its environment.
|
||||
* the state must include information about all aspects of the past agent-environment interaction that make a differnce
|
||||
for the future.
|
||||
* anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its
|
||||
environment.
|
||||
|
||||
<br>
|
||||
|
||||
##### goals and rewards
|
||||
|
||||
|
||||
* each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to a sample from a standard distribution of starting states.
|
||||
* almost all reinforcement learning algorithms involve estimating value functions—functions of states (or of state–action pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a given action in a given state).
|
||||
* the Bellman equation averages over all the possibilities, weighting each by its probability of occurring. tt states that the value of the start state must equal the
|
||||
* each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to
|
||||
a sample from a standard distribution of starting states.
|
||||
* almost all reinforcement learning algorithms involve estimating value functions—functions of states (or of
|
||||
state–action pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a
|
||||
given action in a given state).
|
||||
* the Bellman equation averages over all the possibilities, weighting each by its probability of occurring.
|
||||
(discounted) value of the expected next state, plus the reward expected along the way.
|
||||
* solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run.
|
||||
* solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run.
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -92,20 +111,28 @@
|
|||
|
||||
### dynamic programming
|
||||
|
||||
* collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as a mdp.
|
||||
* a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state and action spaces and then apply finite-state DP methods.
|
||||
* collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as a
|
||||
mdp.
|
||||
* a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state
|
||||
and action spaces and then apply finite-state DP methods.
|
||||
* the reason for computing the value function for a policy is to help find better policies.
|
||||
* asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other states happen to be available. the values of some states may be updated several times before the values of others ar
|
||||
* policy evaluation refers to the (typi- cally) iterative computation of the value function for a given policy.
|
||||
* asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps
|
||||
of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other
|
||||
states happen to be available. the values of some states may be updated several times before the values of others ar
|
||||
* policy evaluation refers to the (typi- cally) iterative computation of the value function for a given policy.
|
||||
* policy improvement refers to the computation of an improved policy given the value function for that policy.
|
||||
|
||||
<br>
|
||||
|
||||
##### generalized policy interaction
|
||||
|
||||
* policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with the current policy (policy evaluation), and the other making the policy greedy with respect to the current value function (policy improvement).
|
||||
* generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement processes interact, independent of the granularity and other details of the two processes.
|
||||
* DP is sometimes thought to be of limited applicability because of the curse of dimen- sionality, the fact that the number of states often grows exponentially with the number of state variables
|
||||
* policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with
|
||||
the current policy (policy evaluation), and the other making the policy greedy with respect to the current value
|
||||
function (policy improvement).
|
||||
* generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement
|
||||
processes interact, independent of the granularity and other details of the two processes.
|
||||
* DP is sometimes thought to be of limited applicability because of the curse of dimen- sionality, the fact that the
|
||||
number of states often grows exponentially with the number of state variables
|
||||
|
||||
<br>
|
||||
|
||||
|
@ -116,4 +143,4 @@
|
|||
<br>
|
||||
|
||||
* **[gymnasium api](https://gymnasium.farama.org/)**
|
||||
* **[reinforcement learning with unsupervised auxiliary tasks, by jaderberg et al.](https://arxiv.org/abs/1611.05397)**
|
||||
* **[reinforcement learning with unsupervised auxiliary tasks, by jaderberg et al.](https://arxiv.org/abs/1611.05397)**
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
<br>
|
||||
|
||||
* **[opeanai](opeanai)**
|
||||
* **[google's gemini](gemini)**
|
||||
* **[openAI](openAI)**
|
||||
* **[claude](claude)**
|
||||
* **[gemini](gemini)**
|
||||
* **[deepseek](deepseek)**
|
||||
|
||||
<br>
|
||||
|
@ -17,7 +17,7 @@
|
|||
|
||||
#### articles
|
||||
|
||||
* **[people cannot distinguish gpt-4 from a human in a turing test, by c. jones et al (2024)](https://arxiv.org/pdf/2405.08007)**
|
||||
* **[people cannot distinguish gpt-4 from a human in a turing test, by c.
|
||||
|
||||
<br>
|
||||
|
||||
|
|
|
@ -2,6 +2,12 @@
|
|||
|
||||
<br>
|
||||
|
||||
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
### cool resources
|
||||
|
||||
<br>
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
<br>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/user-attachments/assets/42b8c4ac-4359-422a-a0a4-dd4ff0ec6e75" width="60%" align="center" style="padding:1px;border:1px solid black;" />
|
||||
<img src="https://github.com/user-attachments/assets/a1b2b912-8700-439f-8ad8-db415c94ad0b" width="60%" align="center" style="padding:1px;border:1px solid black;" />
|
||||
<img src="https://github.com/user-attachments/assets/42b8c4ac-4359-422a-a0a4-dd4ff0ec6e75" width="60%" align="center"
|
||||
style="padding:1px;border:1px solid black;" />
|
||||
<img src="https://github.com/user-attachments/assets/a1b2b912-8700-439f-8ad8-db415c94ad0b" width="60%" align="center"
|
||||
style="padding:1px;border:1px solid black;" />
|
||||
</p>
|
||||
|
|
|
@ -1 +1 @@
|
|||
## gemini
|
||||
## gemini
|
||||
|
|
|
@ -1,16 +1,25 @@
|
|||
## openai
|
||||
## openAI
|
||||
|
||||
<br>
|
||||
|
||||
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
### cool resources
|
||||
|
||||
<br>
|
||||
|
||||
* **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode) (and [here](https://marketplace.visualstudio.com/items?itemName=timkmecl.chatgpt))**
|
||||
* **[scispace extension (paper explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)**
|
||||
* **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode)**
|
||||
* **[scispace extension (paper
|
||||
explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)**
|
||||
* **[fix python bugs](https://platform.openai.com/playground/p/default-fix-python-bugs?model=code-davinci-002)**
|
||||
* **[explain code](https://platform.openai.com/playground/p/default-explain-code?model=code-davinci-002)**
|
||||
* **[translate code](https://platform.openai.com/playground/p/default-translate-code?model=code-davinci-002)**
|
||||
* **[translate sql](https://platform.openai.com/playground/p/default-sql-translate?model=code-davinci-002)**
|
||||
* **[calculate time complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)**
|
||||
* **[text to programmatic command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)**
|
||||
* **[calculate time
|
||||
complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)**
|
||||
* **[text to programmatic
|
||||
command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)**
|
||||
|
|
68
pyproject.toml
Normal file
68
pyproject.toml
Normal file
|
@ -0,0 +1,68 @@
|
|||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py39']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
# directories
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| venv
|
||||
| _build
|
||||
| buck-out
|
||||
| build
|
||||
| dist
|
||||
)/
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
multi_line_output = 3
|
||||
line_length = 88
|
||||
known_first_party = []
|
||||
known_third_party = []
|
||||
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
|
||||
skip = ["venv", ".venv", "__pycache__"]
|
||||
|
||||
[tool.autopep8]
|
||||
max_line_length = 88
|
||||
aggressive = 2
|
||||
experimental = true
|
||||
|
||||
[tool.autoflake]
|
||||
remove-all-unused-imports = true
|
||||
remove-unused-variables = true
|
||||
remove-duplicate-keys = true
|
||||
ignore-init-module-imports = true
|
||||
|
||||
[tool.flake8]
|
||||
max-line-length = 88
|
||||
extend-ignore = ["E203", "W503"]
|
||||
exclude = [
|
||||
".git",
|
||||
"__pycache__",
|
||||
"venv",
|
||||
".venv",
|
||||
"build",
|
||||
"dist",
|
||||
".eggs"
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.9"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
warn_no_return = true
|
||||
warn_unreachable = true
|
||||
strict_equality = true
|
628
scripts/auto_fix.py
Executable file
628
scripts/auto_fix.py
Executable file
|
@ -0,0 +1,628 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
this script fixes common quality issues including:
|
||||
- code formatting (black, isort, autopep8, autoflake)
|
||||
- markdown formatting (line length, trailing whitespace)
|
||||
- markdown link validation (internal and external)
|
||||
- python code quality issues
|
||||
- import organization
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class AutoFixer:
|
||||
|
||||
def __init__(self):
|
||||
self.venv_path = "venv"
|
||||
self.fixes_applied = 0
|
||||
self.errors_encountered = 0
|
||||
self.link_report = {
|
||||
"total_links": 0,
|
||||
"internal_links": 0,
|
||||
"external_links": 0,
|
||||
"broken_internal": 0,
|
||||
"broken_external": 0,
|
||||
"broken_links": [],
|
||||
}
|
||||
|
||||
def fix_python_code(self) -> bool:
|
||||
print("\n🐍 fixing python code...")
|
||||
python_files = list(Path(".").rglob("*.py"))
|
||||
python_files = [
|
||||
f
|
||||
for f in python_files
|
||||
if not any(
|
||||
part.startswith(".") or part in ["venv", "__pycache__", ".venv"]
|
||||
for part in f.parts
|
||||
)
|
||||
]
|
||||
|
||||
if not python_files:
|
||||
print("ℹ️ no python files found to fix")
|
||||
return True
|
||||
|
||||
print(f"ℹ️ found {len(python_files)} Python files to fix")
|
||||
|
||||
print("🔧 autoflake - removing unused imports...")
|
||||
if self._run_autoflake(python_files):
|
||||
self.fixes_applied += 1
|
||||
print("✅ autoflake completed")
|
||||
else:
|
||||
print("⚠️ autoflake had issues")
|
||||
|
||||
print("🔧 autopep8 - fixing code style...")
|
||||
if self._run_autopep8(python_files):
|
||||
self.fixes_applied += 1
|
||||
print("✅ autopep8 completed")
|
||||
else:
|
||||
print("⚠️ autopep8 had issues")
|
||||
|
||||
print("🔧 isort - organizing imports...")
|
||||
if self._run_isort(python_files):
|
||||
self.fixes_applied += 1
|
||||
print("✅ isort completed")
|
||||
else:
|
||||
print("⚠️ isort had issues")
|
||||
|
||||
print("🔧 black - applying consistent formatting...")
|
||||
if self._run_black(python_files):
|
||||
self.fixes_applied += 1
|
||||
print("✅ black completed")
|
||||
else:
|
||||
print("⚠️ black had issues")
|
||||
return True
|
||||
|
||||
def _run_autoflake(self, python_files: list[Path]) -> bool:
|
||||
try:
|
||||
for file_path in python_files:
|
||||
cmd = [
|
||||
f"{self.venv_path}/bin/autoflake",
|
||||
"--in-place",
|
||||
"--remove-all-unused-imports",
|
||||
"--remove-unused-variables",
|
||||
str(file_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"autoflake warning for {file_path}: {result.stderr}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"autoflake error: {e}")
|
||||
|
||||
def _run_autopep8(self, python_files: list[Path]) -> bool:
|
||||
try:
|
||||
for file_path in python_files:
|
||||
cmd = [
|
||||
f"{self.venv_path}/bin/autopep8",
|
||||
"--in-place",
|
||||
"--aggressive",
|
||||
"--aggressive",
|
||||
str(file_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"autopep8 warning for {file_path}: {result.stderr}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"autopep8 error: {e}")
|
||||
|
||||
def _run_isort(self, python_files: list[Path]) -> bool:
|
||||
try:
|
||||
for file_path in python_files:
|
||||
cmd = [f"{self.venv_path}/bin/isort", str(file_path)]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"isort warning for {file_path}: {result.stderr}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"isort error: {e}")
|
||||
|
||||
def _run_black(self, python_files: list[Path]) -> bool:
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
for file_path in python_files:
|
||||
cmd = [f"{self.venv_path}/bin/black", str(file_path)]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print(f"black warning for {file_path}: {result.stderr}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"black error: {e}")
|
||||
|
||||
def fix_markdown_files(self) -> bool:
|
||||
print("\n📝 fixing markdown files...")
|
||||
markdown_files = list(Path(".").rglob("*.md"))
|
||||
markdown_files = [
|
||||
f
|
||||
for f in markdown_files
|
||||
if not any(
|
||||
part.startswith(".") or part in ["venv", "__pycache__", ".venv"]
|
||||
for part in f.parts
|
||||
)
|
||||
]
|
||||
|
||||
if not markdown_files:
|
||||
print("ℹ️ no markdown files found to fix")
|
||||
return True
|
||||
|
||||
print(f"ℹ️ found {len(markdown_files)} markdown files to fix")
|
||||
|
||||
print("🔗 checking markdown links...")
|
||||
self.check_markdown_links(markdown_files)
|
||||
self.fix_common_link_issues(markdown_files)
|
||||
|
||||
for file_path in markdown_files:
|
||||
if self.fix_single_markdown_file(file_path):
|
||||
self.fixes_applied += 1
|
||||
return True
|
||||
|
||||
def check_markdown_links(self, markdown_files: list[Path]) -> None:
|
||||
all_links = []
|
||||
|
||||
for file_path in markdown_files:
|
||||
links = self.extract_links_from_markdown(file_path)
|
||||
for link in links:
|
||||
link["source_file"] = file_path
|
||||
all_links.append(link)
|
||||
|
||||
if not all_links:
|
||||
print("ℹ️ no links found in markdown files")
|
||||
return
|
||||
|
||||
self.link_report["total_links"] = len(all_links)
|
||||
print(f"ℹ️ found {len(all_links)} links to check")
|
||||
|
||||
internal_links = [link for link in all_links if not link["is_external"]]
|
||||
self.link_report["internal_links"] = len(internal_links)
|
||||
if internal_links:
|
||||
print(f"🔍 checking {len(internal_links)} internal links...")
|
||||
self.check_internal_links(internal_links)
|
||||
|
||||
external_links = [link for link in all_links if link["is_external"]]
|
||||
self.link_report["external_links"] = len(external_links)
|
||||
if external_links:
|
||||
print(f"🌐 checking {len(external_links)} external links...")
|
||||
self.check_external_links(external_links)
|
||||
|
||||
def extract_links_from_markdown(self, file_path: Path) -> list[dict]:
|
||||
links = []
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
link_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
|
||||
matches = re.findall(link_pattern, content)
|
||||
|
||||
for text, url in matches:
|
||||
is_external = url.startswith(("http://", "https://", "mailto:"))
|
||||
|
||||
links.append(
|
||||
{
|
||||
"text": text.strip(),
|
||||
"url": url.strip(),
|
||||
"is_external": is_external,
|
||||
"line_number": self.get_line_number_for_link(
|
||||
content, text, url
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ error reading {file_path}: {e}")
|
||||
|
||||
return links
|
||||
|
||||
def get_line_number_for_link(self, content: str, text: str, url: str) -> int:
|
||||
lines = content.split("\n")
|
||||
for i, line in enumerate(lines, 1):
|
||||
if f"[{text}]({url})" in line:
|
||||
return i
|
||||
return 0
|
||||
|
||||
def check_internal_links(self, internal_links: list[dict]) -> None:
|
||||
broken_links = []
|
||||
|
||||
for link in internal_links:
|
||||
source_file = link["source_file"]
|
||||
url = link["url"]
|
||||
text = link["text"]
|
||||
line_num = link["line_number"]
|
||||
|
||||
if url.startswith("../"):
|
||||
target_path = source_file.parent.parent / url[3:]
|
||||
elif url.startswith("./"):
|
||||
target_path = source_file.parent / url[2:]
|
||||
else:
|
||||
target_path = source_file.parent / url
|
||||
|
||||
if not target_path.exists():
|
||||
broken_link_info = {
|
||||
"source_file": source_file,
|
||||
"text": text,
|
||||
"url": url,
|
||||
"line_number": line_num,
|
||||
"target_path": target_path,
|
||||
"issue": "File not found",
|
||||
"type": "internal",
|
||||
}
|
||||
broken_links.append(broken_link_info)
|
||||
self.link_report["broken_links"].append(broken_link_info)
|
||||
|
||||
self.link_report["broken_internal"] = len(broken_links)
|
||||
if broken_links:
|
||||
print(f"❌ found {len(broken_links)} broken internal links:")
|
||||
for link in broken_links:
|
||||
print(
|
||||
f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
|
||||
)
|
||||
else:
|
||||
print("✅ all internal links are valid")
|
||||
|
||||
def check_external_links(self, external_links: list[dict]) -> None:
|
||||
broken_links = []
|
||||
checked_count = 0
|
||||
|
||||
for link in external_links:
|
||||
url = link["url"]
|
||||
text = link["text"]
|
||||
source_file = link["source_file"]
|
||||
line_num = link["line_number"]
|
||||
|
||||
try:
|
||||
time.sleep(0.1)
|
||||
|
||||
response = requests.head(url, timeout=10, allow_redirects=True)
|
||||
checked_count += 1
|
||||
|
||||
if response.status_code >= 400:
|
||||
if (
|
||||
response.status_code == 429
|
||||
or response.status_code == 403
|
||||
or response.status_code == 443
|
||||
):
|
||||
if checked_count % 10 == 0:
|
||||
print(
|
||||
f" checked {checked_count}/{len(external_links)} external links..."
|
||||
)
|
||||
continue
|
||||
|
||||
broken_link_info = {
|
||||
"source_file": source_file,
|
||||
"text": text,
|
||||
"url": url,
|
||||
"line_number": line_num,
|
||||
"status_code": response.status_code,
|
||||
"issue": f"HTTP {response.status_code}",
|
||||
"type": "external",
|
||||
}
|
||||
broken_links.append(broken_link_info)
|
||||
self.link_report["broken_links"].append(broken_link_info)
|
||||
|
||||
if checked_count % 10 == 0:
|
||||
print(
|
||||
f" checked {checked_count}/{len(external_links)} external links..."
|
||||
)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
broken_link_info = {
|
||||
"source_file": source_file,
|
||||
"text": text,
|
||||
"url": url,
|
||||
"line_number": line_num,
|
||||
"issue": f"Connection error: {str(e)}",
|
||||
"type": "external",
|
||||
}
|
||||
broken_links.append(broken_link_info)
|
||||
self.link_report["broken_links"].append(broken_link_info)
|
||||
except Exception as e:
|
||||
broken_link_info = {
|
||||
"source_file": source_file,
|
||||
"text": text,
|
||||
"url": url,
|
||||
"line_number": line_num,
|
||||
"issue": f"Error: {str(e)}",
|
||||
"type": "external",
|
||||
}
|
||||
broken_links.append(broken_link_info)
|
||||
self.link_report["broken_links"].append(broken_link_info)
|
||||
|
||||
self.link_report["broken_external"] = len(broken_links)
|
||||
if broken_links:
|
||||
print(f"❌ found {len(broken_links)} broken external links:")
|
||||
for link in broken_links:
|
||||
print(
|
||||
f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
|
||||
)
|
||||
else:
|
||||
print("✅ all external links are accessible")
|
||||
|
||||
print(f"ℹ️ checked {checked_count} external links")
|
||||
|
||||
def fix_common_link_issues(self, markdown_files: list[Path]) -> None:
|
||||
print("🔧 fixing common link issues...")
|
||||
fixed_count = 0
|
||||
|
||||
for file_path in markdown_files:
|
||||
if self.fix_links_in_file(file_path):
|
||||
fixed_count += 1
|
||||
|
||||
if fixed_count > 0:
|
||||
print(f"✅ fixed links in {fixed_count} files")
|
||||
self.fixes_applied += 1
|
||||
else:
|
||||
print("ℹ️ no link issues found to fix")
|
||||
|
||||
def fix_links_in_file(self, file_path: Path) -> bool:
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
changes_made = False
|
||||
eth_pattern = r"([a-zA-Z0-9]+\.eth)"
|
||||
if re.search(eth_pattern, content):
|
||||
content = re.sub(eth_pattern, r"\1".replace(".eth", ""), content)
|
||||
changes_made = True
|
||||
|
||||
double_space_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
|
||||
|
||||
def fix_spaces(match):
|
||||
text = match.group(1).strip()
|
||||
url = match.group(2).strip()
|
||||
if text != match.group(1) or url != match.group(2):
|
||||
return f"[{text}]({url})"
|
||||
return match.group(0)
|
||||
|
||||
new_content = re.sub(double_space_pattern, fix_spaces, content)
|
||||
if new_content != content:
|
||||
content = new_content
|
||||
changes_made = True
|
||||
|
||||
if changes_made:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
print(f" 🔧 fixed links in {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ error fixing links in {file_path}: {e}")
|
||||
|
||||
return False
|
||||
|
||||
def fix_single_markdown_file(self, file_path: Path) -> bool:
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
lines.copy()
|
||||
fixed_lines = []
|
||||
changes_made = False
|
||||
in_code_block = False
|
||||
|
||||
for _, line in enumerate(lines):
|
||||
line = line.rstrip()
|
||||
|
||||
if not line.strip():
|
||||
fixed_lines.append("\n")
|
||||
continue
|
||||
|
||||
if line.startswith("```"):
|
||||
in_code_block = not in_code_block
|
||||
fixed_lines.append(line + "\n")
|
||||
continue
|
||||
|
||||
if in_code_block:
|
||||
fixed_lines.append(line + "\n")
|
||||
continue
|
||||
|
||||
if line.strip().startswith(("- ", "* ", "+ ", "1. ")):
|
||||
if len(line) > 120:
|
||||
broken_line = self._break_list_item(line)
|
||||
if broken_line != line:
|
||||
changes_made = True
|
||||
print(f" breaking long list item in {file_path}")
|
||||
fixed_lines.append(broken_line + "\n")
|
||||
else:
|
||||
fixed_lines.append(line + "\n")
|
||||
continue
|
||||
|
||||
if len(line) > 120:
|
||||
broken_lines = self.break_long_line(line)
|
||||
if broken_lines != line:
|
||||
changes_made = True
|
||||
print(
|
||||
f" breaking long line in {file_path}: {len(line)} chars -> {len(broken_lines.split(chr(10))[0])} chars"
|
||||
)
|
||||
|
||||
for broken_line in broken_lines.split("\n"):
|
||||
if broken_line.strip():
|
||||
fixed_lines.append(broken_line + "\n")
|
||||
else:
|
||||
fixed_lines.append("\n")
|
||||
else:
|
||||
fixed_lines.append(line + "\n")
|
||||
|
||||
if changes_made:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.writelines(fixed_lines)
|
||||
print(f" ✅ fixed {file_path}")
|
||||
return True
|
||||
else:
|
||||
print(f" ℹ️ no changes needed in {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ error fixing {file_path}: {e}")
|
||||
self.errors_encountered += 1
|
||||
|
||||
def break_long_line(self, line: str) -> str:
|
||||
if len(line) <= 120:
|
||||
return line
|
||||
|
||||
if ". " in line:
|
||||
parts = line.split(". ")
|
||||
if len(parts[0]) <= 120:
|
||||
remaining = ". ".join(parts[1:])
|
||||
if len(remaining) <= 120:
|
||||
return parts[0] + ". " + remaining
|
||||
else:
|
||||
broken_remaining = self._break_at_words(remaining)
|
||||
return parts[0] + ".\n" + broken_remaining
|
||||
return self._break_at_words(line)
|
||||
|
||||
def _break_at_words(self, line: str) -> str:
|
||||
words = line.split()
|
||||
result = []
|
||||
current_line = ""
|
||||
|
||||
for word in words:
|
||||
if current_line and len(current_line + " " + word) > 120:
|
||||
if current_line:
|
||||
result.append(current_line)
|
||||
current_line = word
|
||||
else:
|
||||
if current_line:
|
||||
current_line += " " + word
|
||||
else:
|
||||
current_line = word
|
||||
|
||||
if current_line:
|
||||
result.append(current_line)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def _break_list_item(self, line: str) -> str:
|
||||
marker_end = 0
|
||||
for i, char in enumerate(line):
|
||||
if char in "-*+" or (char.isdigit() and line[i + 1 : i + 3] == ". "):
|
||||
marker_end = line.find(" ", i)
|
||||
if marker_end == -1:
|
||||
marker_end = len(line)
|
||||
break
|
||||
|
||||
if marker_end == 0:
|
||||
return self.break_long_line(line)
|
||||
|
||||
marker = line[: marker_end + 1]
|
||||
content = line[marker_end + 1 :]
|
||||
|
||||
if len(content) <= 120 - len(marker):
|
||||
return line
|
||||
|
||||
broken_content = self._break_at_words(content)
|
||||
if "\n" in broken_content:
|
||||
indent = " " * len(marker)
|
||||
lines = broken_content.split("\n")
|
||||
result = [marker + lines[0]]
|
||||
for continuation_line in lines[1:]:
|
||||
result.append(indent + continuation_line)
|
||||
return "\n".join(result)
|
||||
|
||||
return line
|
||||
|
||||
def fix_trailing_whitespace(self) -> bool:
|
||||
print("\n🧹 fixing trailing whitespace...")
|
||||
|
||||
text_extensions = {".py", ".md", ".txt", ".rst", ".yml", ".yaml", ".json"}
|
||||
files_fixed = 0
|
||||
|
||||
for root, dirs, files in os.walk("."):
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if not d.startswith(".")
|
||||
and d not in ["venv", "__pycache__", ".venv", "node_modules"]
|
||||
]
|
||||
|
||||
for file in files:
|
||||
if any(file.endswith(ext) for ext in text_extensions):
|
||||
file_path = os.path.join(root, file)
|
||||
if self.fix_file_trailing_whitespace(file_path):
|
||||
files_fixed += 1
|
||||
|
||||
if files_fixed > 0:
|
||||
self.fixes_applied += 1
|
||||
print(f"ℹ️ fixed trailing whitespace in {files_fixed} files")
|
||||
return True
|
||||
|
||||
def fix_file_trailing_whitespace(self, file_path: str) -> bool:
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
original_lines = lines.copy()
|
||||
fixed_lines = []
|
||||
|
||||
for line in lines:
|
||||
fixed_line = line.rstrip() + "\n"
|
||||
fixed_lines.append(fixed_line)
|
||||
|
||||
if fixed_lines != original_lines:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.writelines(fixed_lines)
|
||||
print(f"ℹ️ fixed trailing whitespace in {file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"warning: could not fix {file_path}: {e}")
|
||||
|
||||
def run_all_fixes(self) -> bool:
|
||||
print("🚀 starting auto-fix process...")
|
||||
success = True
|
||||
success &= self.fix_trailing_whitespace()
|
||||
success &= self.fix_python_code()
|
||||
success &= self.fix_markdown_files()
|
||||
return success
|
||||
|
||||
def print_summary(self):
|
||||
print("\n" + "=" * 50)
|
||||
print("🎯 AUTO-FIX SUMMARY")
|
||||
print("=" * 50)
|
||||
print(f"✅ fixes applied: {self.fixes_applied}")
|
||||
print(f"❌ errors encountered: {self.errors_encountered}")
|
||||
|
||||
if self.link_report["total_links"] > 0:
|
||||
print("\n🔗 LINK CHECK SUMMARY")
|
||||
print("-" * 30)
|
||||
print(f"📊 total links found: {self.link_report['total_links']}")
|
||||
print(f"🔍 internal links: {self.link_report['internal_links']}")
|
||||
print(f"🌐 external links: {self.link_report['external_links']}")
|
||||
print(f"❌ broken internal: {self.link_report['broken_internal']}")
|
||||
print(f"❌ broken external: {self.link_report['broken_external']}")
|
||||
|
||||
if self.link_report["broken_links"]:
|
||||
print(f"\n⚠️ broken links found:")
|
||||
for link in self.link_report["broken_links"]:
|
||||
print(
|
||||
f" - {link['source_file']}:{link['line_number']} - '{link['text']}' -> {link['url']} ({link['issue']})"
|
||||
)
|
||||
|
||||
if (
|
||||
self.errors_encountered == 0
|
||||
and self.link_report["broken_internal"] == 0
|
||||
and self.link_report["broken_external"] == 0
|
||||
):
|
||||
print("\n🎉 all fixes completed successfully and all links are working!")
|
||||
elif self.errors_encountered == 0:
|
||||
print(f"\n✅ all fixes completed successfully!")
|
||||
print(
|
||||
f"⚠️ but {self.link_report['broken_internal'] + self.link_report['broken_external']} broken links were found"
|
||||
)
|
||||
else:
|
||||
print(f"\n⚠️ {self.errors_encountered} errors occurred during fixing")
|
||||
|
||||
|
||||
def main():
|
||||
fixer = AutoFixer()
|
||||
fixer.run_all_fixes()
|
||||
fixer.print_summary()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
10
scripts/requirements.txt
Normal file
10
scripts/requirements.txt
Normal file
|
@ -0,0 +1,10 @@
|
|||
textstat>=0.7.3
|
||||
requests>=2.28.0
|
||||
beautifulsoup4>=4.11.0
|
||||
markdown>=3.4.0
|
||||
black>=23.0.0
|
||||
isort>=5.12.0
|
||||
flake8>=6.0.0
|
||||
mypy>=1.0.0
|
||||
autopep8>=2.0.0
|
||||
autoflake>=2.0.0
|
Loading…
Add table
Add a link
Reference in a new issue