# ml-ai-agents-py/ebm_sandbox.py
# snapshot: 2020-05-10 22:32:26 -07:00 (982 lines, 35 KiB, Python)
import tensorflow as tf
import math
from tqdm import tqdm
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
import torch
from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, DspritesNet
from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, DSprites
from utils import optimistic_restore, set_seed
import os.path as osp
import numpy as np
from baselines.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
import sklearn.metrics as sk
from baselines.common.tf_util import initialize
from scipy.linalg import eig
import matplotlib.pyplot as plt
# set_seed(1)
# --- Experiment selection -------------------------------------------------
flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'omniglot or imagenet or omniglotfull or cifar10 or mnist or dsprites')
flags.DEFINE_string('logdir', 'sandbox_cachedir', 'location where log of experiments will be stored')
# 'task' selects which sandbox experiment main() dispatches to.
flags.DEFINE_string('task', 'label', 'using conditional energy based models for classification'
                    'anticorrupt: restore salt and pepper noise),'
                    ' boxcorrupt: restore empty portion of image'
                    'or crossclass: change images from one class to another'
                    'or cycleclass: view image change across a label'
                    'or nearestneighbor which returns the nearest images in the test set'
                    'or latent to traverse the latent space of an EBM through eigenvectors of the hessian (dsprites only)'
                    'or mixenergy to evaluate out of distribution generalization compared to other datasets')
flags.DEFINE_bool('hessian', True, 'Whether to use the hessian or the Jacobian for latent traversals')
flags.DEFINE_string('exp', 'default', 'name of experiments')
# --- Data loading / sampling ----------------------------------------------
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 32, 'Size of inputs')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
# --- Model architecture ---------------------------------------------------
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('train', True, 'Whether to train or test network')
flags.DEFINE_bool('single', False, 'whether to use one sample to debug')
flags.DEFINE_bool('cclass', True, 'whether to use a conditional model (required for task label)')
# --- Langevin refinement --------------------------------------------------
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
flags.DEFINE_float('step_lr', 10.0, 'step size for updates on label')
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
flags.DEFINE_bool('large_model', False, 'Whether to use a large model')
flags.DEFINE_bool('larger_model', False, 'Whether to use a larger model')
flags.DEFINE_bool('wider_model', False, 'Whether to use a widermodel model')
flags.DEFINE_bool('svhn', False, 'Whether to test on SVHN')
# Conditions for mixenergy (outlier detection)
flags.DEFINE_bool('svhnmix', False, 'Whether to test mix on SVHN')
flags.DEFINE_bool('cifar100mix', False, 'Whether to test mix on CIFAR100')
flags.DEFINE_bool('texturemix', False, 'Whether to test mix on Textures dataset')
flags.DEFINE_bool('randommix', False, 'Whether to test mix on random dataset')
# Conditions for label task (adversarial classification)
flags.DEFINE_integer('lival', 8, 'Value of constraint for li')
flags.DEFINE_integer('l2val', 40, 'Value of constraint for l2')
flags.DEFINE_integer('pgd', 0, 'number of steps project gradient descent to run')
flags.DEFINE_integer('lnorm', -1, 'linfinity is -1, l2 norm is 2')
flags.DEFINE_bool('labelgrid', False, 'Make a grid of labels')
# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
flags.DEFINE_bool('cond_size', True, 'whether to condition on scale')
FLAGS = flags.FLAGS
def rescale_im(im):
    """Convert a float image with values in [0, 1] (clipped) to uint8 [0, 255]."""
    clipped = np.clip(im, 0, 1)
    return np.round(clipped * 255).astype(np.uint8)
def label(dataloader, test_dataloader, target_vars, sess, l1val=8, l2val=40):
    """Evaluate adversarial classification accuracy over the test set.

    Runs the precompiled `accuracy` op on every test batch with the given
    l-infinity (`l1val`) and l2 (`l2val`) attack budgets fed in, prints the
    running mean accuracy, and returns the overall mean.
    """
    X = target_vars['X']
    Y = target_vars['Y']
    Y_GT = target_vars['Y_GT']
    accuracy = target_vars['accuracy']
    l1_norm = target_vars['l1_norm']
    l2_norm = target_vars['l2_norm']

    # One block of all 10 one-hot labels per example: the graph tiles X to
    # (batch*10, ...) so Y must be (batch*10, 10).
    # (The original also built a random-uniform init here and immediately
    # overwrote it — dead code removed.)
    label_init = np.tile(np.eye(10)[None, :, :], (FLAGS.batch_size, 1, 1))
    label_init = np.reshape(label_init, (-1, 10))

    emp_accuracies = []
    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
        emp_accuracy = sess.run([accuracy], feed_dict)
        emp_accuracies.append(emp_accuracy)
        print(np.array(emp_accuracies).mean())

    print("Received total accuracy of {} for li of {} and l2 of {}".format(
        np.array(emp_accuracies).mean(), l1val, l2val))
    return np.array(emp_accuracies).mean()
def labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=8, l2val=40):
    """Optionally fine-tune the classifier head, then evaluate test accuracy.

    When FLAGS.train is set, runs one epoch of `train_op` over `dataloader`
    and saves the weights as "model_supervised"; then restores that
    checkpoint and evaluates accuracy with the given attack budgets.
    Returns the mean test accuracy.
    """
    X = target_vars['X']
    Y = target_vars['Y']
    Y_GT = target_vars['Y_GT']
    accuracy = target_vars['accuracy']
    train_op = target_vars['train_op']
    l1_norm = target_vars['l1_norm']
    l2_norm = target_vars['l2_norm']

    # All 10 one-hot labels per example (graph tiles X to batch*10 rows).
    # Dead random-uniform initialization removed — it was overwritten.
    label_init = np.tile(np.eye(10)[None, :, :], (FLAGS.batch_size, 1, 1))
    label_init = np.reshape(label_init, (-1, 10))

    if FLAGS.train:
        itr = 0
        for data_corrupt, data, label_gt in tqdm(dataloader):
            feed_dict = {X: data, Y_GT: label_gt, Y: label_init}
            acc, _ = sess.run([accuracy, train_op], feed_dict)
            itr += 1
            if itr % 10 == 0:
                print(acc)
        saver.save(sess, osp.join(savedir, "model_supervised"))

    saver.restore(sess, osp.join(savedir, "model_supervised"))

    emp_accuracies = []
    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        feed_dict = {X: data, Y_GT: label_gt, Y: label_init, l1_norm: l1val, l2_norm: l2val}
        emp_accuracy = sess.run([accuracy], feed_dict)
        emp_accuracies.append(emp_accuracy)
        print(np.array(emp_accuracies).mean())

    print("Received total accuracy of {} for li of {} and l2 of {}".format(
        np.array(emp_accuracies).mean(), l1val, l2val))
    return np.array(emp_accuracies).mean()
def energyeval(dataloader, test_dataloader, target_vars, sess):
    """Compute per-example energies on the train and test splits, print
    summary statistics, and dump the raw values to .npy files."""
    X, Y_GT = target_vars['X'], target_vars['Y_GT']
    energy = target_vars['energy']
    energy_end = target_vars['energy_end']

    def batch_energies(loader):
        # Flatten the per-batch energy vectors into one long list.
        collected = []
        for _, images, labels_gt in tqdm(loader):
            vals = sess.run([energy], {X: images, Y_GT: labels_gt})[0]
            collected.extend(list(vals))
        return collected

    test_energies = batch_energies(test_dataloader)
    train_energies = batch_energies(dataloader)

    print(len(train_energies))
    print(len(test_energies))
    print("Train energies of {} with std {}".format(np.mean(train_energies), np.std(train_energies)))
    print("Test energies of {} with std {}".format(np.mean(test_energies), np.std(test_energies)))

    np.save("train_ebm.npy", train_energies)
    np.save("test_ebm.npy", test_energies)
def energyevalmix(dataloader, test_dataloader, target_vars, sess):
    """Out-of-distribution detection: compare energies of in-distribution
    test images against a contaminated set (SVHN / CIFAR100 / Textures /
    uniform noise / in-batch image mixtures) and report the AUROC of
    -energy as an in-distribution score. Saves pos/neg scores to .npy.
    """
    X = target_vars['X']
    Y_GT = target_vars['Y_GT']
    energy = target_vars['energy']

    # Build the contamination loader, if one of the dataset mixes is chosen.
    if FLAGS.svhnmix:
        mix_dataset = Svhn(train=False)
    elif FLAGS.cifar100mix:
        mix_dataset = Cifar100(train=False)
    elif FLAGS.texturemix:
        mix_dataset = Textures()
    else:
        mix_dataset = None

    test_iter = None
    if mix_dataset is not None:
        test_dataloader_val = DataLoader(mix_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)
        test_iter = iter(test_dataloader_val)

    probs = []
    labels = []
    negs = []
    pos = []
    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()

        if test_iter is not None:
            # Py3 fix: use next(it) instead of the Py2-only it.next().
            _, data_mix, _ = next(test_iter)
        elif FLAGS.randommix:
            data_mix = np.random.randn(FLAGS.batch_size, 32, 32, 3) * 0.5 + 0.5
        else:
            # Mix each image with its neighbor in the batch (wrap-around).
            data_idx = np.concatenate([np.arange(1, data.shape[0]), [0]])
            data_other = data[data_idx]
            data_mix = (data + data_other) / 2

        # Trim to the in-distribution batch size.
        data_mix = data_mix[:data.shape[0]]

        if FLAGS.cclass:
            # Conditional model: score each image under every label and keep
            # the minimum energy per image below.
            # It's unfair to take a random class
            label_gt = np.tile(np.eye(10), (data.shape[0], 1, 1))
            label_gt = label_gt.reshape(data.shape[0] * 10, 10)
            data_mix = np.tile(data_mix[:, None, :, :, :], (1, 10, 1, 1, 1))
            data = np.tile(data[:, None, :, :, :], (1, 10, 1, 1, 1))
            data_mix = data_mix.reshape(-1, 32, 32, 3)
            data = data.reshape(-1, 32, 32, 3)

        pos_energy = sess.run([energy], {X: data, Y_GT: label_gt})[0]
        neg_energy = sess.run([energy], {X: data_mix, Y_GT: label_gt})[0]

        if FLAGS.cclass:
            pos_energy = pos_energy.reshape(-1, 10).min(axis=1)
            neg_energy = neg_energy.reshape(-1, 10).min(axis=1)

        # Negative energy acts as the "in-distribution" score.
        probs.extend(list(-1*pos_energy))
        probs.extend(list(-1*neg_energy))
        pos.extend(list(-1*pos_energy))
        negs.extend(list(-1*neg_energy))
        labels.extend([1]*pos_energy.shape[0])
        labels.extend([0]*neg_energy.shape[0])

    pos, negs = np.array(pos), np.array(negs)
    np.save("pos.npy", pos)
    np.save("neg.npy", negs)
    auroc = sk.roc_auc_score(labels, probs)
    print("Roc score of {}".format(auroc))
def anticorrupt(dataloader, weights, model, target_vars, logdir, sess):
    """Corrupt images with salt-and-pepper noise, restore them with the EBM
    sampler, and save a corrupted/restored/original comparison panel for the
    first batch only.
    """
    X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
    for data_corrupt, data, label_gt in tqdm(dataloader):
        data, label_gt = data.numpy(), label_gt.numpy()

        # Per-pixel noise mask shared across channels: ~5% pepper, ~5% salt.
        noise = np.random.uniform(0, 1, size=[data.shape[0], data.shape[1], data.shape[2]])
        low_mask = noise < 0.05
        high_mask = (noise > 0.05) & (noise < 0.1)
        data_corrupt = data.copy()
        data_corrupt[low_mask] = 0.1
        data_corrupt[high_mask] = 0.9
        data_corrupt_init = data_corrupt

        # Five rounds of the compiled refinement chain.
        for i in range(5):
            data_corrupt = sess.run([X_final], {X: data_corrupt, Y_GT: label_gt})[0]
        data_uncorrupt = data_corrupt

        data_corrupt, data_uncorrupt, data = (rescale_im(data_corrupt_init),
                                              rescale_im(data_uncorrupt),
                                              rescale_im(data))

        # 20 rows: corrupted | restored | original.
        panel_im = np.zeros((32*20, 32*3, 3)).astype(np.uint8)
        for i in range(20):
            panel_im[32*i:32*i+32, :32] = data_corrupt[i]
            panel_im[32*i:32*i+32, 32:64] = data_uncorrupt[i]
            panel_im[32*i:32*i+32, 64:] = data[i]
        imsave(osp.join(logdir, "anticorrupt.png"), panel_im)
        # Only the first batch is visualized; the original aborted here with
        # a bare `assert False` (raises, and is stripped under -O).
        return
def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess):
    """Replace the bottom half of each image with uniform noise, restore it
    with 10 rounds of the refinement chain, and record the per-image MSE to
    the original for ~10k train and ~10k test images (saved as .npy).
    """
    X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
    eval_im = 10000

    def _reconstruction_errors(loader):
        # Per-image mean squared error between restored and original images.
        diffs = []
        for data_corrupt, data, label_gt in tqdm(loader):
            data, label_gt = data.numpy(), label_gt.numpy()
            data_corrupt = data.copy()
            # Noise the bottom half; use the actual batch size rather than
            # FLAGS.batch_size so a short final batch cannot crash.
            data_corrupt[:, 16:, :] = np.random.uniform(0, 1, (data.shape[0], 16, 32, 3))
            for _ in range(10):
                data_corrupt = sess.run([X_final], {X: data_corrupt, Y_GT: label_gt})[0]
            diffs.extend(list(np.mean(np.square(data_corrupt - data), axis=(1, 2, 3))))
            if len(diffs) > eval_im:
                break
        return diffs

    data_diff = _reconstruction_errors(dataloader)
    print("Mean {} and std {} for train dataloader".format(np.mean(data_diff), np.std(data_diff)))
    np.save("data_diff_train_image.npy", data_diff)

    data_diff = _reconstruction_errors(test_dataloader)
    print("Mean {} and std {} for test dataloader".format(np.mean(data_diff), np.std(data_diff)))
    np.save("data_diff_test_image.npy", data_diff)
def crossclass(dataloader, weights, model, target_vars, logdir, sess):
    """Initialize each chain from a *different* image (batch shifted by one)
    so refinement must morph one class into another; save the trajectory of
    the first batch as a panel image.
    """
    X, Y_GT, X_mods, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_mods'], target_vars['X_final']
    for data_corrupt, data, label_gt in tqdm(dataloader):
        data, label_gt = data.numpy(), label_gt.numpy()

        # Shift the batch by one (with wrap-around) so image i starts from
        # image i-1 but is refined under label i.
        data_corrupt = data.copy()
        data_corrupt[1:] = data_corrupt[0:-1]
        data_corrupt[0] = data[-1]

        data_mods = []
        data_mod = data_corrupt
        for i in range(10):
            data_mods.append(data_mod)
            data_mod = sess.run(X_final, {X: data_mod, Y_GT: label_gt})

        data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
        data_mods = [rescale_im(m) for m in data_mods]

        # Rows: start | 10 refinement snapshots | target original.
        panel_im = np.zeros((32*20, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
        for i in range(20):
            panel_im[32*i:32*i+32, :32] = data_corrupt[i]
            for j in range(len(data_mods)):
                panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
            panel_im[32*i:32*i+32, -32:] = data[i]
        imsave(osp.join(logdir, "crossclass.png"), panel_im)
        # Only the first batch is visualized (was a bare `assert False`).
        return
def cycleclass(dataloader, weights, model, target_vars, logdir, sess):
    """Run 20 rounds of the refinement chain from the corrupted init and
    save the per-round trajectory of the first batch as a panel image.
    """
    # X, Y_GT, X_final, X_targ = target_vars['X'], target_vars['Y_GT'], target_vars['X_final'], target_vars['X_targ']
    X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
    for data_corrupt, data, label_gt in tqdm(dataloader):
        data, label_gt = data.numpy(), label_gt.numpy()
        data_corrupt = data_corrupt.numpy()

        data_mods = []
        x_curr = data_corrupt
        for i in range(20):
            x_curr = sess.run(X_final, {X: x_curr, Y_GT: label_gt})
            data_mods.append(x_curr)
        # (Removed dead code: an `x_target` re-randomized under `if i > 30`,
        # unreachable with a 20-step loop and never fed to the graph.)

        data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
        data_mods = [rescale_im(m) for m in data_mods]

        # Clamp the panel to the rows actually available: the original
        # indexed 100 rows unconditionally and crashed for batch_size < 100.
        rows = min(100, data.shape[0])
        panel_im = np.zeros((32*rows, 32*(len(data_mods) + 2), 3)).astype(np.uint8)
        for i in range(rows):
            panel_im[32*i:32*i+32, :32] = data_corrupt[i]
            for j in range(len(data_mods)):
                panel_im[32*i:32*i+32, 32*(j+1):32*(j+2)] = data_mods[j][i]
            panel_im[32*i:32*i+32, -32:] = data[i]
        imsave(osp.join(logdir, "cycleclass.png"), panel_im)
        # Only the first batch is visualized (was a bare `assert False`).
        return
def democlass(dataloader, weights, model, target_vars, logdir, sess):
    """Sample 5 images per CIFAR-10 class from uniform noise and assemble a
    5x10-tile class-demo panel saved to `logdir`."""
    X, Y_GT, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_final']
    panel_im = np.zeros((5*32, 10*32, 3)).astype(np.uint8)

    for class_idx in range(10):
        seed = np.random.uniform(0, 1, (5, 32, 32, 3))
        onehot = np.tile(np.eye(10)[class_idx:class_idx+1], (5, 1))
        samples = sess.run([X_final], {X: seed, Y_GT: onehot})[0]
        samples = rescale_im(samples)

        # Layout: two classes per panel row, five 32px tiles per class.
        row_idx = (class_idx // 2) * 32
        start_idx = (class_idx % 2) * 32 * 5
        for k in range(5):
            panel_im[row_idx:row_idx+32, start_idx+k*32:start_idx+(k+1) * 32] = samples[k]

    imsave(osp.join(logdir, "democlass.png"), panel_im)
def construct_finetune_label(weight, X, Y, Y_GT, model, target_vars):
    """Build graph ops for supervised fine-tuning / evaluation of the EBM as
    a classifier: logits are negative energies of X paired with each of the
    10 one-hot labels, optionally after an unrolled PGD attack on X.

    Adds 'accuracy', 'train_op', 'l1_norm' and 'l2_norm' to target_vars.
    """
    # Attack-budget placeholders (fed at evaluation time).
    l1_norm = tf.placeholder(shape=(), dtype=tf.float32)
    l2_norm = tf.placeholder(shape=(), dtype=tf.float32)

    def compute_logit(X, stop_grad=False, num_steps=0):
        # Tile each image 10x and pair it with every one-hot label; the
        # (batch, 10) matrix of negative energies serves as class logits.
        batch_size = tf.shape(X)[0]
        X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
        X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
        Y_new = tf.reshape(Y, (batch_size*10, 10))
        X_min = X - 8 / 255.
        X_max = X + 8 / 255.
        # Optional Langevin refinement of X before scoring.
        # NOTE(review): this loop references `weights`, which is not defined
        # in this function (the parameter is named `weight`) — it would raise
        # a NameError if called with num_steps > 0. All call sites below use
        # num_steps=0, so the loop never runs; confirm before enabling.
        for i in range(num_steps):
            X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
            energy_noise = model.forward(X, weights, label=Y, reuse=True)
            x_grad = tf.gradients(energy_noise, [X])[0]
            if FLAGS.proj_norm != 0.0:
                x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
            X = X - FLAGS.step_lr * x_grad
            X = tf.maximum(tf.minimum(X, X_max), X_min)
        energy = model.forward(X, weight, label=Y_new)
        energy = -tf.reshape(energy, (batch_size, 10))
        if stop_grad:
            energy = tf.stop_gradient(energy)
        return energy

    # Unrolled PGD attack on the input; skipped entirely when FLAGS.train.
    for i in range(FLAGS.pgd):
        if FLAGS.train:
            break
        print("Constructed loop {} of pgd attack".format(i))
        X_init = X
        if i == 0:
            # Random integer start within the 8/255 l-infinity ball.
            X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
        logit = compute_logit(X)
        loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
        x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
        X = X + 2 * x_grad
        if FLAGS.lnorm == -1:
            # NOTE(review): X_max / X_min are only defined inside
            # compute_logit, not in this scope — this branch would raise a
            # NameError if FLAGS.pgd > 0 with lnorm == -1. Confirm.
            X = tf.maximum(tf.minimum(X, X_max), X_min)
        elif FLAGS.lnorm == 2:
            X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])

    # Final logits (no refinement) and supervised training ops.
    energy = compute_logit(X, num_steps=0)
    logits = energy
    labels = tf.argmax(Y_GT, axis=1)
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logits)
    optimizer = tf.train.AdamOptimizer(1e-3)
    train_op = optimizer.minimize(loss)
    accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, axis=1), labels)

    target_vars['accuracy'] = accuracy
    target_vars['train_op'] = train_op
    target_vars['l1_norm'] = l1_norm
    target_vars['l2_norm'] = l2_norm
def construct_latent(weights, X, Y_GT, model, target_vars):
    """Build ops for latent-space traversals of a DSprites EBM (64x64 input).

    Traversal directions come either from eigenvectors of the energy Hessian
    (FLAGS.hessian) or from the SVD of the logit Jacobian. Adds 'v',
    'x_stack', 'x_refine' and 'es' to target_vars.
    """
    eps = 0.001  # NOTE(review): unused in this function
    X_init = X[0:1]  # traversals are built around the first image only

    def traversals(model, X, weights, Y_GT):
        # Return an orthonormal direction basis at X: eigenvectors of the
        # (64*64 x 64*64) energy Hessian, or right-singular vectors of the
        # 128-logit Jacobian.
        if FLAGS.hessian:
            e_pos = model.forward(X, weights, label=Y_GT)
            hessian = tf.hessians(e_pos, X)
            hessian = tf.reshape(hessian, (1, 64*64, 64*64))[0]
            e, v = tf.linalg.eigh(hessian)
        else:
            latent = model.forward(X, weights, label=Y_GT, return_logit=True)
            latents = tf.split(latent, 128, axis=1)
            jacobian = [tf.gradients(latent, X)[0] for latent in latents]
            jacobian = tf.stack(jacobian, axis=1)
            jacobian = tf.reshape(jacobian, (tf.shape(jacobian)[1], tf.shape(jacobian)[1], 64*64))
            s, _, v = tf.linalg.svd(jacobian)
        return v

    var_scale = 1.0  # step size along each traversal direction
    n = 3            # number of directions; +/- each gives 6 images
    xs = []
    v = traversals(model, X_init, weights, Y_GT)
    for i in range(n):
        var = tf.reshape(v[:, i], (1, 64, 64))
        X_plus = X_init - var_scale * var
        X_min = X_init + var_scale * var
        xs.extend([X_plus, X_min])

    x_stack = tf.stack(xs, axis=0)
    e_pos_hess_modify = model.forward(x_stack, weights, label=Y_GT)  # NOTE(review): unused
    # 20 Langevin steps to settle the perturbed images back on the manifold.
    for i in range(20):
        x_stack = x_stack + tf.random_normal(tf.shape(x_stack), mean=0.0, stddev=0.005)
        e_pos = model.forward(x_stack, weights, label=Y_GT)
        x_grad = tf.gradients(e_pos, [x_stack])[0]
        x_stack = x_stack - 4*FLAGS.step_lr * x_grad
        x_stack = tf.clip_by_value(x_stack, 0, 1)

    # Second stage: feed 6 images back in, step each along its own local
    # direction (index j//2, sign alternating with j), then refine again.
    x_mods = tf.split(X, 6)
    eigs = []
    for j in range(6):
        x_mod = x_mods[j]
        v = traversals(model, x_mod, weights, Y_GT)
        idx = j // 2
        var = tf.reshape(v[:, idx], (1, 64, 64))
        if j % 2 == 1:
            x_mod = x_mod + var_scale * var
            eigs.append(var)
        else:
            x_mod = x_mod - var_scale * var
            eigs.append(-var)
        x_mod = tf.clip_by_value(x_mod, 0, 1)
        x_mods[j] = x_mod

    x_mods_stack = tf.stack(x_mods, axis=0)
    eigs_stack = tf.stack(eigs, axis=0)
    energys = []
    for i in range(20):
        x_mods_stack = x_mods_stack + tf.random_normal(tf.shape(x_mods_stack), mean=0.0, stddev=0.005)
        e_pos = model.forward(x_mods_stack, weights, label=Y_GT)
        x_grad = tf.gradients(e_pos, [x_mods_stack])[0]
        x_mods_stack = x_mods_stack - 4*FLAGS.step_lr * x_grad
        # x_mods_stack = x_mods_stack + 0.1 * eigs_stack
        x_mods_stack = tf.clip_by_value(x_mods_stack, 0, 1)
        energys.append(e_pos)

    x_refine = x_mods_stack
    es = tf.stack(energys, axis=0)
    # target_vars['hessian'] = hessian
    # target_vars['e'] = e
    target_vars['v'] = v
    target_vars['x_stack'] = x_stack
    target_vars['x_refine'] = x_refine
    target_vars['es'] = es
    # target_vars['e_base'] = e_pos_base
def latent(test_dataloader, weights, model, target_vars, sess):
    """Run latent-space traversals on one DSprites image and save the
    trajectory as a single panel image ("latent_comb.png")."""
    X = target_vars['X']
    Y_GT = target_vars['Y_GT']
    # hessian = target_vars['hessian']
    # e = target_vars['e']
    v = target_vars['v']
    x_stack = target_vars['x_stack']
    x_refine = target_vars['x_refine']
    es = target_vars['es']

    # Py3 fix: use the next() builtin instead of the Py2-only .next().
    data_corrupt, data, label_gt = next(iter(test_dataloader))
    data = data.numpy()
    x_init = np.tile(data[0:1], (6, 1, 1))

    # Initial +/- traversals around the first image.
    x_mod, = sess.run([x_stack], {X: data})
    x_mod = x_mod.squeeze()

    n = 6
    x_mod_list = [x_init, x_mod]
    for i in range(n):
        # Each round perturbs along local directions and refines again.
        x_mod, evals = sess.run([x_refine, es], {X: x_mod})
        x_mod = x_mod.squeeze()
        x_mod_list.append(x_mod)
        print("Value of energies after evaluation: ", evals)

    x_mod_list = x_mod_list[:]
    series_xmod = np.stack(x_mod_list, axis=1)
    series_header = np.tile(data[0:1, None, :, :], (1, len(x_mod_list), 1, 1))
    series_total = np.concatenate([series_header, series_xmod], axis=0)

    # Pad each 64x64 frame into a 66x66 cell (1px white border).
    series_total_full = np.ones((*series_total.shape[:-2], 66, 66))
    series_total_full[:, :, 1:-1, 1:-1] = series_total
    series_total = series_total_full

    series_total = series_total.transpose((0, 2, 1, 3)).reshape((-1, len(x_mod_list)*66))
    im_total = rescale_im(series_total)
    imsave("latent_comb.png", im_total)
def construct_label(weights, X, Y, Y_GT, model, target_vars):
    """Build classification ops for the 'label' task: logits are negative
    energies of X under all 10 labels, computed after FLAGS.num_steps of
    box-constrained refinement; an optional unrolled PGD attack perturbs X
    first. Adds 'accuracy', 'train_op', 'l1_norm', 'l2_norm' to target_vars.
    """
    # (dead experiment kept for reference: gradient descent on Y directly)
    # for i in range(FLAGS.num_steps):
    #     Y = Y + tf.random_normal(tf.shape(Y), mean=0.0, stddev=0.03)
    #     e = model.forward(X, weights, label=Y)
    #     Y_grad = tf.clip_by_value(tf.gradients(e, [Y])[0], -1, 1)
    #     Y = Y - 0.1 * Y_grad
    #     Y = tf.clip_by_value(Y, 0, 1)
    #     Y = Y / tf.reduce_sum(Y, axis=[1], keepdims=True)

    # Trainable per-class bias added to the logits.
    e_bias = tf.get_variable('e_bias', shape=10, initializer=tf.initializers.zeros())
    # Attack-budget placeholders (fed at evaluation time).
    l1_norm = tf.placeholder(shape=(), dtype=tf.float32)
    l2_norm = tf.placeholder(shape=(), dtype=tf.float32)

    def compute_logit(X, stop_grad=False, num_steps=0):
        # Pair every image with all 10 one-hot labels; negative energies
        # reshaped to (batch, 10) serve as logits.
        batch_size = tf.shape(X)[0]
        X = tf.reshape(X, (batch_size, 1, 32, 32, 3))
        X = tf.reshape(tf.tile(X, (1, 10, 1, 1, 1)), (batch_size * 10, 32, 32, 3))
        Y_new = tf.reshape(Y, (batch_size*10, 10))
        # Refinement stays inside an 8/255 box around the tiled input.
        X_min = X - 8 / 255.
        X_max = X + 8 / 255.
        for i in range(num_steps):
            X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
            energy_noise = model.forward(X, weights, label=Y, reuse=True)
            x_grad = tf.gradients(energy_noise, [X])[0]
            if FLAGS.proj_norm != 0.0:
                x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
            X = X - FLAGS.step_lr * x_grad
            X = tf.maximum(tf.minimum(X, X_max), X_min)
        energy = model.forward(X, weights, label=Y_new)
        energy = -tf.reshape(energy, (batch_size, 10))
        if stop_grad:
            energy = tf.stop_gradient(energy)
        return energy

    # eps_norm = 30
    # l-infinity attack box in pixel units (budget fed via l1_norm).
    X_min = X - l1_norm / 255.
    X_max = X + l1_norm / 255.
    # Unrolled PGD attack on the input image.
    for i in range(FLAGS.pgd):
        print("Constructed loop {} of pgd attack".format(i))
        X_init = X
        if i == 0:
            # Random integer start inside the 8/255 ball.
            X = X + tf.to_float(tf.random_uniform(tf.shape(X), minval=-8, maxval=9, dtype=tf.int32)) / 255.
        logit = compute_logit(X)
        loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=logit)
        x_grad = tf.sign(tf.gradients(loss, [X])[0]) / 255.
        X = X + 2 * x_grad
        if FLAGS.lnorm == -1:
            X = tf.maximum(tf.minimum(X, X_max), X_min)
        elif FLAGS.lnorm == 2:
            X = X_init + tf.clip_by_norm(X - X_init, l2_norm / 255., axes=[1, 2, 3])

    energy_stopped = compute_logit(X, stop_grad=True, num_steps=FLAGS.num_steps) + e_bias
    # # Y = tf.Print(Y, [Y])
    labels = tf.argmax(Y_GT, axis=1)
    # max_z = tf.argmax(energy_stopped, axis=1)
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_GT, logits=energy_stopped)
    # Logits are stop-gradient'd, so effectively only e_bias trains here.
    optimizer = tf.train.AdamOptimizer(1e-2)
    train_op = optimizer.minimize(loss)
    accuracy = tf.contrib.metrics.accuracy(tf.argmax(energy_stopped, axis=1), labels)

    target_vars['accuracy'] = accuracy
    target_vars['train_op'] = train_op
    target_vars['l1_norm'] = l1_norm
    target_vars['l2_norm'] = l2_norm
def construct_energy(weights, X, Y, Y_GT, model, target_vars):
    """Build ops for the initial energy of X and the energy after
    FLAGS.num_steps of Langevin refinement.

    Adds 'energy' (initial) and 'energy_end' (post-refinement) to
    target_vars.
    """
    energy = model.forward(X, weights, label=Y_GT)
    # Initialize to the starting energy so 'energy_end' is defined even when
    # num_steps == 0 (previously a NameError in that case).
    energy_noise = energy
    for i in range(FLAGS.num_steps):
        X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
        energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
        x_grad = tf.gradients(energy_noise, [X])[0]
        if FLAGS.proj_norm != 0.0:
            x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
        X = X - FLAGS.step_lr * x_grad
        X = tf.clip_by_value(X, 0, 1)

    target_vars['energy'] = energy
    target_vars['energy_end'] = energy_noise
def construct_steps(weights, X, Y_GT, model, target_vars):
    """Unroll FLAGS.num_steps Langevin refinement steps on X.

    For the 'boxcorrupt' task only the bottom image half is updated (via a
    fixed mask); otherwise the whole image is refined. Stores the final
    tensor as target_vars['X_final'] and every n-th snapshot in 'X_mods'.
    """
    n = 50  # snapshot interval for X_mods
    scale_fac = 1.0
    # if FLAGS.task == 'cycleclass':
    #     scale_fac = 10.0
    X_mods = []
    X = tf.identity(X)

    # Mask selecting which pixels the sampler may modify.
    mask = np.zeros((1, 32, 32, 3))
    if FLAGS.task == "boxcorrupt":
        mask[:, 16:, :, :] = 1
    else:
        mask[:, :, :, :] = 1
    mask = tf.Variable(tf.convert_to_tensor(mask, dtype=tf.float32), trainable=False)
    # X_targ = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)

    for i in range(FLAGS.num_steps):
        # (removed unused local X_old)
        X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005*scale_fac) * mask
        energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
        x_grad = tf.gradients(energy_noise, [X])[0]
        if FLAGS.proj_norm != 0.0:
            x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
        X = X - FLAGS.step_lr * x_grad * scale_fac * mask
        X = tf.clip_by_value(X, 0, 1)
        if i % n == (n-1):
            X_mods.append(X)
        print("Constructing step {}".format(i))

    target_vars['X_final'] = X
    target_vars['X_mods'] = X_mods
def nearest_neighbor(dataset, sess, target_vars, logdir):
    """Sample 10 images (one per class) from noise, then save each next to
    its 5 nearest neighbors (squared L2 in pixel space) from `dataset`.

    `dataset` is an array of N 32x32x3 images scaled to [0, 1].
    """
    X = target_vars['X']
    Y_GT = target_vars['Y_GT']
    x_final = target_vars['X_final']

    noise = np.random.uniform(0, 1, size=[10, 32, 32, 3])
    # label = np.random.randint(0, 10, size=[10])
    label = np.eye(10)  # one sample per class
    coarse = noise
    for i in range(10):
        x_new = sess.run([x_final], {X: coarse, Y_GT: label})[0]
        coarse = x_new

    x_new_dense = x_new.reshape(10, 1, 32*32*3)
    # Generalized: infer the dataset size instead of hard-coding 50000.
    dataset_dense = dataset.reshape(1, -1, 32*32*3)
    diff = np.square(x_new_dense - dataset_dense).sum(axis=2)
    diff_idx = np.argsort(diff, axis=1)

    # Rows: sample | its 5 nearest dataset neighbors.
    panel = np.zeros((32*10, 32*6, 3))
    dataset_rescale = rescale_im(dataset)
    x_new_rescale = rescale_im(x_new)
    for i in range(10):
        panel[i*32:i*32+32, :32] = x_new_rescale[i]
        for j in range(5):
            panel[i*32:i*32+32, 32*j+32:32*j+64] = dataset_rescale[diff_idx[i, j]]
    imsave(osp.join(logdir, "nearest.png"), panel)
def main():
    """Entry point: build the dataset, model and task-specific graph,
    restore a checkpoint, and dispatch to the experiment in FLAGS.task."""
    # --- Dataset selection -------------------------------------------------
    if FLAGS.dataset == "cifar10":
        dataset = Cifar10(train=True, noise=False)
        test_dataset = Cifar10(train=False, noise=False)
    else:
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)

    if FLAGS.svhn:
        dataset = Svhn(train=True)
        test_dataset = Svhn(train=False)

    if FLAGS.task == 'latent':
        # Latent traversal task always uses DSprites (64x64 grayscale).
        dataset = DSprites()
        test_dataset = dataset

    dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True)

    # --- Model selection ---------------------------------------------------
    hidden_dim = 128
    if FLAGS.large_model:
        model = ResNet32Large(num_filters=hidden_dim)
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=hidden_dim)
    elif FLAGS.wider_model:
        if FLAGS.dataset == 'imagenet':
            model = ResNet32Wider(num_filters=196, train=False)
        else:
            model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=hidden_dim)

    if FLAGS.task == 'latent':
        model = DspritesNet()

    weights = model.construct_weights('context_{}'.format(0))

    # Count trainable parameters for a sanity log line.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    config = tf.ConfigProto()  # NOTE(review): built but never passed to the session
    sess = tf.InteractiveSession()

    # --- Placeholders ------------------------------------------------------
    if FLAGS.task == 'latent':
        X = tf.placeholder(shape=(None, 64, 64), dtype = tf.float32)
    else:
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)

    # NOTE(review): Y / Y_GT are only created for cifar10 and imagenet; any
    # other FLAGS.dataset would hit a NameError at target_vars below —
    # confirm intended usage before adding datasets.
    if FLAGS.dataset == "cifar10":
        Y = tf.placeholder(shape=(None, 10), dtype = tf.float32)
        Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
    elif FLAGS.dataset == "imagenet":
        Y = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
        Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)

    target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}

    # --- Build the graph for the requested task ----------------------------
    if FLAGS.task == 'label':
        construct_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'labelfinetune':
        construct_finetune_label(weights, X, Y, Y_GT, model, target_vars, )
    elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
        construct_energy(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor':
        construct_steps(weights, X, Y_GT, model, target_vars)
    elif FLAGS.task == 'latent':
        construct_latent(weights, X, Y_GT, model, target_vars)

    sess.run(tf.global_variables_initializer())
    saver = loader = tf.train.Saver(max_to_keep=10)
    savedir = osp.join('cachedir', FLAGS.exp)
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)

    initialize()
    # --- Restore checkpoint ------------------------------------------------
    if FLAGS.resume_iter != -1:
        model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter

        # Some tasks add extra variables (e.g. e_bias, Adam slots) that are
        # absent from the checkpoint, so restore only matching names there.
        if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy":
            optimistic_restore(sess, model_file)
            # saver.restore(sess, model_file)
        else:
            # optimistic_restore(sess, model_file)
            saver.restore(sess, model_file)

    # --- Dispatch to the selected experiment --------------------------------
    if FLAGS.task == 'label':
        if FLAGS.labelgrid:
            # Sweep the attack budget and save accuracy per budget value.
            vals = []
            if FLAGS.lnorm == -1:
                for i in range(31):
                    accuracies = label(dataloader, test_dataloader, target_vars, sess, l1val=i)
                    vals.append(accuracies)
            elif FLAGS.lnorm == 2:
                for i in range(0, 100, 5):
                    accuracies = label(dataloader, test_dataloader, target_vars, sess, l2val=i)
                    vals.append(accuracies)
            np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
        else:
            label(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'labelfinetune':
        labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val)
    elif FLAGS.task == 'energyeval':
        energyeval(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'mixenergy':
        energyevalmix(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'anticorrupt':
        anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'boxcorrupt':
        # boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
        boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'crossclass':
        crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'cycleclass':
        cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'democlass':
        democlass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'nearestneighbor':
        # print(dir(dataset))
        # print(type(dataset))
        nearest_neighbor(dataset.data.train_data / 255, sess, target_vars, logdir)
    elif FLAGS.task == 'latent':
        latent(test_dataloader, weights, model, target_vars, sess)
# Run the sandbox when executed as a script.
if __name__ == "__main__":
    main()