# mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git
# synced 2025-08-16 02:00:32 -04:00
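"""Sample images from an ensemble of trained energy-based models with Langevin
dynamics and report Inception and FID scores.

A minimal invocation sketch (the script filename and checkpoint layout are
assumptions inferred from the flags and restore logic below):

    python test_inception.py --logdir cachedir --exp default \
        --resume_iter 117000 --ensemble 10 --dataset cifar10

Checkpoints are expected at <logdir>/<exp>/model_<iteration>.
"""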
import os.path as osp
import random

import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf
from baselines.common.tf_util import initialize
from data import Cifar10, Imagenet, TFImagenetLoader
from fid import get_fid_score
from inception import get_inception_score
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
# scipy.misc.imsave was removed in scipy >= 1.2; this import requires an older
# scipy (imageio.imwrite is the usual drop-in replacement).
from scipy.misc import imsave
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import optimistic_remap_restore

hvd.init()

flags.DEFINE_string(
    "logdir", "cachedir", "location where logs of experiments will be stored"
)
flags.DEFINE_string("exp", "default", "name of experiment")
flags.DEFINE_bool("cclass", False, "whether to condition on class")

# Architecture settings
flags.DEFINE_bool("bn", False, "whether to use batch normalization")
flags.DEFINE_bool("spec_norm", True, "whether to use spectral normalization on weights")
flags.DEFINE_bool("use_bias", True, "whether to use bias in convolutions")
flags.DEFINE_bool("use_attention", False, "whether to use self-attention in the network")

# Sampling and evaluation settings
flags.DEFINE_float("step_lr", 10.0, "step size for Langevin gradient descent")
flags.DEFINE_integer("num_steps", 20, "number of Langevin dynamics steps per sample")
flags.DEFINE_float("proj_norm", 0.05, "maximum per-step gradient magnitude (clip value)")
flags.DEFINE_integer("batch_size", 512, "batch size")
flags.DEFINE_integer("resume_iter", -1, "iteration of the checkpoint to resume from")
flags.DEFINE_integer("ensemble", 10, "number of models in the ensemble")
flags.DEFINE_integer("im_number", 50000, "number of images to generate for evaluation")
flags.DEFINE_integer("repeat_scale", 100, "scales the number of sampling iterations per buffer")
flags.DEFINE_float("noise_scale", 0.005, "amount of noise added during sampling")
flags.DEFINE_integer("idx", 0, "save index for the output image panel")
flags.DEFINE_integer("nomix", 10, "number of final iterations during which buffer mixing is disabled")
flags.DEFINE_bool("scaled", True, "whether to scale the noise added")
flags.DEFINE_bool("large_model", False, "whether to use the large model variant")
flags.DEFINE_bool("larger_model", False, "whether to use the even larger model variant")
flags.DEFINE_bool("wider_model", False, "whether to use the wider model variant")
flags.DEFINE_bool("single", False, "whether to evaluate a single model")
flags.DEFINE_string("datasource", "random", "default or noise or negative or single")
flags.DEFINE_string("dataset", "cifar10", "cifar10 or imagenet or imagenetfull")

FLAGS = flags.FLAGS


class InceptionReplayBuffer(object):
    def __init__(self, size):
        """Create a replay buffer for generated images and their labels.

        Parameters
        ----------
        size: int
            Max number of samples to store in the buffer. When the buffer
            overflows, the oldest entries are overwritten.
        """
        self._storage = []
        self._label_storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def add(self, ims, labels):
        batch_size = ims.shape[0]
        if self._next_idx >= len(self._storage):
            # Buffer is still growing: append the new batch.
            self._storage.extend(list(ims))
            self._label_storage.extend(list(labels))
        else:
            if batch_size + self._next_idx < self._maxsize:
                # Overwrite a contiguous block of old entries.
                self._storage[self._next_idx : self._next_idx + batch_size] = list(ims)
                self._label_storage[self._next_idx : self._next_idx + batch_size] = (
                    list(labels)
                )
            else:
                # The batch wraps around the end of the buffer.
                split_idx = self._maxsize - self._next_idx
                self._storage[self._next_idx :] = list(ims)[:split_idx]
                self._storage[: batch_size - split_idx] = list(ims)[split_idx:]
                self._label_storage[self._next_idx :] = list(labels)[:split_idx]
                self._label_storage[: batch_size - split_idx] = list(labels)[split_idx:]

        self._next_idx = (self._next_idx + batch_size) % self._maxsize

    def _encode_sample(self, idxes):
        ims = []
        labels = []
        for i in idxes:
            ims.append(self._storage[i])
            labels.append(self._label_storage[i])
        return np.array(ims), np.array(labels)

    def sample(self, batch_size):
        """Sample a batch of stored images and labels.

        Parameters
        ----------
        batch_size: int
            How many samples to draw (with replacement).

        Returns
        -------
        (ims, labels): (np.array, np.array)
            Batch of stored images and their one-hot labels.
        idxes: list of int
            Buffer indices of the sampled entries, usable with set_elms.
        """
        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes), idxes

    def set_elms(self, idxes, data, labels):
        for i, ix in enumerate(idxes):
            self._storage[ix] = data[i]
            self._label_storage[ix] = labels[i]

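# Minimal usage sketch for the buffer above (shapes are illustrative only):
#   buf = InceptionReplayBuffer(1000)
#   buf.add(np.zeros((512, 32, 32, 3)), np.eye(10)[np.random.randint(0, 10, 512)])
#   (ims, labels), idxes = buf.sample(128)
#   buf.set_elms(idxes, ims, labels)  # write refined samples back in place

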
def rescale_im(im):
    """Map images from [0, 1] floats to [0, 255] uint8."""
    return np.clip(im * 256, 0, 255).astype(np.uint8)


def compute_inception(sess, target_vars):
    X_START = target_vars["X_START"]
    Y_GT = target_vars["Y_GT"]
    X_finals = target_vars["X_finals"]
    NOISE_SCALE = target_vars["NOISE_SCALE"]
    energy_noise = target_vars["energy_noise"]

    size = FLAGS.im_number
    # Images are generated in chunks of 1000 (one replay buffer per chunk).
    num_steps = size // 1000

    images = []
    test_ims = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(full=True, noise=False)
    elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
        test_dataset = Imagenet(train=False)

    if FLAGS.dataset != "imagenetfull":
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            num_workers=4,
            shuffle=True,
            drop_last=False,
        )
    else:
        test_dataloader = TFImagenetLoader("test", FLAGS.batch_size, 0, 1)

    # Collect real test images as the reference set for FID.
    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()
        test_ims.extend(list(rescale_im(data)))

        if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
            test_ims = test_ims[:60000]
            break

    print("Collected {} reference images".format(len(test_ims)))
    # fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
    # print("Base FID of score {}".format(fid))

    # Number of classes and image resolution for the chosen dataset.
    if FLAGS.dataset == "cifar10":
        classes = 10
    else:
        classes = 1000

    if FLAGS.dataset == "imagenetfull":
        n = 128
    else:
        n = 32

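    # Sampling scheme: each outer iteration fills a 1000-image replay buffer.
    # While the buffer is filling, chains start from uniform noise; afterwards,
    # chains are resampled from the buffer, occasionally re-seeded with fresh
    # noise and random labels to keep the chains mixing, refined with Langevin
    # dynamics, and written back in place.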
    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)
        data_buffer = InceptionReplayBuffer(1000)
        curr_index = 0
        e_noise = None

        identity = np.eye(classes)

        for i in tqdm(range(itr)):
            # Cycle through the ensemble members round-robin.
            model_index = curr_index % len(X_finals)
            x_final = X_finals[model_index]

            noise_scale = [1]
            if len(data_buffer) < 1000:
                # Buffer not full yet: start chains from uniform noise.
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
                label = np.random.randint(0, classes, (FLAGS.batch_size))
                label = identity[label]
                x_new = sess.run(
                    [x_final], {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale}
                )[0]
                data_buffer.add(x_new, label)
            else:
                # Resume chains from the buffer, re-seeding a small fraction
                # of images (1%) and labels (10%) with fresh noise.
                (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
                keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99
                label_keep_mask = np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9
                label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
                label_corrupt = identity[label_corrupt]
                x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))

                if i < itr - FLAGS.nomix:
                    x_init[keep_mask] = x_init_corrupt[keep_mask]
                    label[label_keep_mask] = label_corrupt[label_keep_mask]
                # else:
                #     noise_scale = [0.7]

                x_new, e_noise = sess.run(
                    [x_final, energy_noise],
                    {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale},
                )
                data_buffer.set_elms(idx, x_new, label)

            # e_noise is only computed once the buffer is full; the extra
            # check avoids a NameError on the warm-up iterations.
            if FLAGS.im_number != 50000 and e_noise is not None:
                print(np.mean(e_noise), np.std(e_noise))

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims)

        images.extend(list(ims))

        # Note: the output directory is hard-coded and must already exist.
        saveim = osp.join("sandbox_cachedir", FLAGS.exp, "test{}.png".format(FLAGS.idx))

        # Tile the first 100 samples into a 10x10 panel.
        ims = ims[:100]

        if FLAGS.dataset != "imagenetfull":
            im_panel = (
                ims.reshape((10, 10, 32, 32, 3))
                .transpose((0, 2, 1, 3, 4))
                .reshape((320, 320, 3))
            )
        else:
            im_panel = (
                ims.reshape((10, 10, 128, 128, 3))
                .transpose((0, 2, 1, 3, 4))
                .reshape((1280, 1280, 3))
            )
        imsave(saveim, im_panel)

        print("Saved image panel to {}".format(saveim))
        splits = max(1, len(images) // 5000)
        score, std = get_inception_score(images, splits=splits)
        print("Inception score of {} with std of {}".format(score, std))

    # FID score against the real test images.
    # n = min(len(images), len(test_ims))
    fid = get_fid_score(images, test_ims)
    print("FID of {}".format(fid))


def main(model_list):

    if FLAGS.dataset == "imagenetfull":
        model = ResNet128(num_filters=64)
    elif FLAGS.large_model:
        model = ResNet32Large(num_filters=128)
    elif FLAGS.larger_model:
        # num_filters was an undefined `hidden_dim` here; 128 is assumed to
        # match the other ResNet32 variants.
        model = ResNet32Larger(num_filters=128)
    elif FLAGS.wider_model:
        model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=128)

    # config = tf.ConfigProto()
    sess = tf.InteractiveSession()

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    weights = []

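    # Each ensemble member is built under its own variable scope
    # ("context_{i}"), but every checkpoint was saved under "context_0", so
    # variable names are remapped before restoring.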
    for i, model_num in enumerate(model_list):
        weight = model.construct_weights("context_{}".format(i))
        initialize()
        save_file = osp.join(logdir, "model_{}".format(model_num))

        v_list = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope="context_{}".format(i)
        )
        # Strip the ":0" suffix and point each checkpoint name at the
        # corresponding variable in this member's scope.
        v_map = {
            (v.name.replace("context_{}".format(i), "context_0")[:-2]): v
            for v in v_list
        }
        saver = tf.train.Saver(v_map)
        try:
            saver.restore(sess, save_file)
        except BaseException:
            # Fall back to the project's optimistic restore helper.
            optimistic_remap_restore(sess, save_file, i)
        weights.append(weight)

    if FLAGS.dataset == "imagenetfull":
        X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
    else:
        X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)

    if FLAGS.dataset == "cifar10":
        Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    else:
        Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

    NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)

    X_finals = []

    # Separate sampling loop per ensemble member
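    # Each member gets its own unrolled Langevin chain. One step of (projected)
    # Langevin dynamics on the energy E(x, y) is, roughly,
    #   x <- clip(x - step_lr * clip(dE/dx, +/-proj_norm) + eps, 0, 1),
    # with eps ~ N(0, (noise_scale * NOISE_SCALE)^2), which is what
    # langevin_step below implements with TF ops.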
    for weight in weights:
        X = X_START

        steps = tf.constant(0)

        def c(i, x):
            return tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, X):
            scale_rate = 1

            # Inject Gaussian noise.
            X = X + tf.random_normal(
                tf.shape(X),
                mean=0.0,
                stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE,
            )

            energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
            x_grad = tf.gradients(energy_noise, [X])[0]

            # Clip the gradient to bound the per-step change of the image.
            if FLAGS.proj_norm != 0.0:
                x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)

            # Descend the energy and project back to the valid pixel range.
            X = X - FLAGS.step_lr * x_grad * scale_rate
            X = tf.clip_by_value(X, 0, 1)

            counter = counter + 1

            return counter, X

        steps, X = tf.while_loop(c, langevin_step, (steps, X))
        energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
        X_final = X
        X_finals.append(X_final)

    target_vars = {}
    target_vars["X_START"] = X_START
    target_vars["Y_GT"] = Y_GT
    target_vars["X_finals"] = X_finals
    target_vars["NOISE_SCALE"] = NOISE_SCALE
    target_vars["energy_noise"] = energy_noise

    compute_inception(sess, target_vars)


if __name__ == "__main__":
    # Build an ensemble from checkpoints spaced 300 iterations apart,
    # counting back from --resume_iter.
    # model_list = [117000, 116700]
    model_list = [FLAGS.resume_iter - 300 * i for i in range(FLAGS.ensemble)]
    main(model_list)