2023-02-07 07:58:48 -08:00

547 lines
18 KiB
Python

from tensorflow.python.platform import flags
from tensorflow.contrib.data.python.ops import batching, threadpool
import tensorflow as tf
import json
from torch.utils.data import Dataset
import pickle
import os.path as osp
import os
import numpy as np
import time
from scipy.misc import imread, imresize
from skimage.color import rgb2grey
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
from torchvision import transforms
from imagenet_preprocessing import ImagenetPreprocessor
import torch
import torchvision
# Module-wide handle on the parsed command-line flags.
FLAGS = flags.FLAGS
# Directory handed to torchvision as the CIFAR-10 download/cache root.
ROOT_DIR = "./results"

# Dataset Options
flags.DEFINE_string(
    'dsprites_path',
    '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
    'path to dsprites characters')
# Help text fixed: it used to be a copy of the cutout_inside description.
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big',
                    'directory with imagenet tfrecord files and index.json')
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
flags.DEFINE_bool('dsize_only', False, 'fix all factors except for size of objects')
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
# Help text fixed: it used to be a copy of the drot_only description.
flags.DEFINE_bool('dsprites_restrict', False,
                  'restrict dsprites to shape 1 with zero rotation')
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')

# Data augmentation options
flags.DEFINE_bool('cutout_inside', False, 'whether cutoff should always in image')
flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
flags.DEFINE_bool('cutout', False, 'whether to add cutout regularizer to data')
def cutout(mask_color=(0, 0, 0)):
    """Build a cutout transform for channel-first (C, H, W) images.

    The mask size is read from ``FLAGS.cutout_mask_size`` when this factory
    is called; ``FLAGS.cutout_prob`` and ``FLAGS.cutout_inside`` are read on
    every invocation of the returned callable.

    Args:
        mask_color: one fill value per channel; it is broadcast as
            ``(C, 1, 1)`` over the masked region.

    Returns:
        A callable mapping an image to a numpy copy of it in which, with
        probability ``FLAGS.cutout_prob``, a square patch has been
        overwritten with ``mask_color``.
    """
    mask_size_half = FLAGS.cutout_mask_size // 2
    offset = 1 if FLAGS.cutout_mask_size % 2 == 0 else 0

    def _cutout(image):
        # Always work on a copy so the caller's image is never modified.
        image = np.asarray(image).copy()

        if np.random.random() > FLAGS.cutout_prob:
            return image

        # The mask write below indexes the image as (C, H, W), so height and
        # width are axes 1 and 2. Reading shape[:2] would yield (C, H) and
        # clamp the mask to the first C rows.
        h, w = image.shape[1:3]

        if FLAGS.cutout_inside:
            # Keep the mask centre far enough from the border that the whole
            # mask lies inside the image.
            cxmin, cxmax = mask_size_half, w + offset - mask_size_half
            cymin, cymax = mask_size_half, h + offset - mask_size_half
        else:
            cxmin, cxmax = 0, w + offset
            cymin, cymax = 0, h + offset

        cx = np.random.randint(cxmin, cxmax)
        cy = np.random.randint(cymin, cymax)

        xmin = cx - mask_size_half
        ymin = cy - mask_size_half
        xmax = xmin + FLAGS.cutout_mask_size
        ymax = ymin + FLAGS.cutout_mask_size

        # Clip the mask rectangle to the image bounds.
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)

        image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None]
        return image

    return _cutout
class TFImagenetLoader(Dataset):
    """Endless iterator over ImageNet batches read from TFRecord files.

    A ``tf.data`` pipeline is built in ``__init__`` and evaluated in a private
    CPU-only ``tf.Session``. Each ``__next__`` call returns torch tensors
    ``(im_corrupt, im, label)``: uniform noise in ``[0, rescale)``, the image
    batch scaled by ``rescale``, and one-hot labels over 1000 classes.
    """

    def __init__(self, split, batchsize, idx, num_workers, rescale=1):
        """Build the input pipeline and start its session.

        Args:
            split: ``"train"`` selects the training shards; any other value
                selects the files whose name starts with ``validation``.
            batchsize: only used to compute ``len(self)``. The pipeline
                itself batches with ``FLAGS.batch_size``.
            idx: index of this worker, used to pick its slice of the
                training records.
            num_workers: total number of workers sharing the training set.
            rescale: factor applied to images and to the noise range.
        """
        IMAGENET_NUM_TRAIN_IMAGES = 1281167
        IMAGENET_NUM_VAL_IMAGES = 50000

        self.rescale = rescale
        self.curr_sample = 0
        # Built once here: allocating a 1000x1000 identity on every batch in
        # __next__ is wasted work.
        self.one_hot_map = np.eye(1000)

        if split == "train":
            im_length = IMAGENET_NUM_TRAIN_IMAGES
            records_to_skip = im_length * idx // num_workers
            records_to_read = im_length * (idx + 1) // num_workers - records_to_skip
        else:
            im_length = IMAGENET_NUM_VAL_IMAGES

        index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
        with open(index_path) as f:
            metadata = json.load(f)
        counts = metadata['record_counts']

        if split == 'train':
            file_names = sorted(x for x in counts if x.startswith('train'))

            # NOTE(review): result_records_to_skip is computed but never
            # applied to the dataset (there is no ds.skip) -- confirm whether
            # that is intentional.
            result_records_to_skip = None
            files = []
            for filename in file_names:
                records_in_file = counts[filename]
                if records_to_skip >= records_in_file:
                    # This file lies entirely before this worker's slice.
                    records_to_skip -= records_in_file
                    continue
                elif records_to_read > 0:
                    if result_records_to_skip is None:
                        # Record the number to skip in the first file
                        result_records_to_skip = records_to_skip
                    files.append(filename)
                    records_to_read -= (records_in_file - records_to_skip)
                    records_to_skip = 0
                else:
                    break
        else:
            files = sorted(x for x in counts if x.startswith('validation'))

        files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
        preprocess_function = ImagenetPreprocessor(
            128, dtype=tf.float32, train=False).parse_and_preprocess

        ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
        ds = ds.apply(tf.data.TFRecordDataset)
        ds = ds.take(im_length)
        ds = ds.prefetch(buffer_size=FLAGS.batch_size)
        ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
        ds = ds.apply(batching.map_and_batch(
            map_func=preprocess_function,
            batch_size=FLAGS.batch_size,
            num_parallel_batches=4))
        ds = ds.prefetch(buffer_size=2)

        ds_iterator = ds.make_initializable_iterator()
        labels, images = ds_iterator.get_next()
        # Scale by 1/256, add uniform noise below one quantisation step, and
        # clip the result to [0, 1].
        self.images = tf.clip_by_value(
            images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
        self.labels = labels

        # device_count GPU=0 keeps this session off the GPUs.
        config = tf.ConfigProto(device_count={'GPU': 0})
        sess = tf.Session(config=config)
        sess.run(ds_iterator.initializer)

        self.im_length = im_length // batchsize
        self.sess = sess

    def __next__(self):
        """Return the next ``(im_corrupt, im, label)`` batch as torch tensors."""
        self.curr_sample += 1

        im_corrupt = np.random.uniform(
            0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
        label, im = self.sess.run([self.labels, self.images])
        im = im * self.rescale
        # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
        # into a scalar and drops the batch axis from the one-hot output.
        # Labels are shifted by one before indexing the one-hot table.
        label = self.one_hot_map[label.reshape(-1) - 1]

        im = torch.from_numpy(im)
        im_corrupt = torch.from_numpy(im_corrupt)
        label = torch.from_numpy(label)
        return im_corrupt, im, label

    def __iter__(self):
        """Return self; batches are produced by ``__next__``."""
        return self

    def __len__(self):
        """Return the split size divided (floor) by ``batchsize``."""
        return self.im_length
class CelebA(Dataset):
    """CelebA images resized to 32x32, paired with a corrupted copy.

    Each item is ``(im_corrupt, im, label)`` where ``label`` is always 1.
    """

    def __init__(self):
        """List every file in the hard-coded CelebA image directory."""
        self.path = "/root/data/img_align_celeba"
        # NOTE(review): os.listdir order is arbitrary, so the index -> file
        # mapping may differ between machines.
        self.ims = [osp.join(self.path, im) for im in os.listdir(self.path)]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.ims)

    def __getitem__(self, index):
        """Load, resize and corrupt the image at ``index``.

        Raises:
            ValueError: if ``FLAGS.datasource`` is neither ``'default'`` nor
                ``'random'``.
        """
        label = 1

        # With FLAGS.single set, every index maps to the first image.
        if FLAGS.single:
            index = 0

        path = self.ims[index]
        im = imread(path)
        im = imresize(im, (32, 32))
        image_size = 32
        im = im / 255.

        if FLAGS.datasource == 'default':
            # Gaussian perturbation of the real image.
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            # Uniform noise, independent of the image.
            im_corrupt = np.random.uniform(
                0, 1, size=(image_size, image_size, 3))
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unknown datasource {!r}; expected 'default' or 'random'".format(
                    FLAGS.datasource))

        return im_corrupt, im, label
class Cifar10(Dataset):
    """CIFAR-10 samples as ``(im_corrupt, im, one_hot_label)`` numpy triples.

    Images are returned channel-last, 32x32x3.
    """

    def __init__(
            self,
            train=True,
            full=False,
            augment=False,
            noise=True,
            rescale=1.0):
        """Create the underlying torchvision datasets.

        Args:
            train: passed to torchvision's ``CIFAR10`` for the main split.
            full: if True, the test split is appended after the main split.
            augment: if True, apply random crop and horizontal flip, plus
                cutout when ``FLAGS.cutout`` is set.
            noise: if True, scale images by ``rescale`` and add uniform noise
                in ``[0, rescale / 256)``.
            rescale: scale used for the image, the noise and the 'random'
                corruption.
        """
        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.full = full
        self.data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=train,
            download=True)
        # The test split receives the same transform as the main split.
        self.test_data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=False,
            download=True)
        self.one_hot_map = np.eye(10)
        self.noise = noise
        self.rescale = rescale

    def __len__(self):
        """Return the main split size, plus the test split size if ``full``."""
        if self.full:
            return len(self.data) + len(self.test_data)
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(im_corrupt, im, label)`` for ``index``.

        Raises:
            ValueError: if ``FLAGS.datasource`` is neither ``'default'`` nor
                ``'random'``.
        """
        if FLAGS.single:
            # With FLAGS.single set, every index maps to the first sample.
            im, label = self.data[0]
        elif self.full and index >= len(self.data):
            # Indices past the main split address the test split.
            im, label = self.test_data[index - len(self.data)]
        else:
            im, label = self.data[index]

        # The transform yields a torch tensor, or a numpy array when cutout
        # is active. np.asarray accepts both, whereas calling .numpy() on the
        # transposed value fails for the ndarray case.
        im = np.transpose(np.asarray(im), (1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]

        im = im * 255 / 256

        if self.noise:
            im = im * self.rescale + \
                np.random.uniform(0, self.rescale * 1 / 256., im.shape)

        # Reseed numpy's global RNG from the index and the wall clock
        # (presumably to decorrelate DataLoader workers -- TODO confirm).
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, self.rescale, (image_size, image_size, 3))
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unknown datasource {!r}; expected 'default' or 'random'".format(
                    FLAGS.datasource))

        return im_corrupt, im, label
class Cifar100(Dataset):
    """CIFAR-100 samples as ``(im_corrupt, im, one_hot_label)`` numpy triples.

    Images are returned channel-last, 32x32x3.
    """

    def __init__(self, train=True, augment=False):
        """Create the underlying torchvision dataset.

        Args:
            train: passed to torchvision's ``CIFAR100``.
            augment: if True, apply random crop and horizontal flip, plus
                cutout when ``FLAGS.cutout`` is set.
        """
        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.data = CIFAR100(
            "/root/cifar100",
            transform=transform,
            train=train,
            download=True)
        self.one_hot_map = np.eye(100)

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(im_corrupt, im, label)`` for ``index``.

        Raises:
            ValueError: if ``FLAGS.datasource`` is neither ``'default'`` nor
                ``'random'``.
        """
        # With FLAGS.single set, every index maps to the first sample.
        if not FLAGS.single:
            im, label = self.data[index]
        else:
            im, label = self.data[0]

        # The transform yields a torch tensor, or a numpy array when cutout
        # is active. np.asarray accepts both, whereas calling .numpy() on the
        # transposed value fails for the ndarray case.
        im = np.transpose(np.asarray(im), (1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)

        # Reseed numpy's global RNG from the index and the wall clock
        # (presumably to decorrelate DataLoader workers -- TODO confirm).
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unknown datasource {!r}; expected 'default' or 'random'".format(
                    FLAGS.datasource))

        return im_corrupt, im, label
class Svhn(Dataset):
    """SVHN samples as ``(im_corrupt, im, one_hot_label)`` numpy triples.

    Images are returned channel-last, 32x32x3.
    """

    def __init__(self, train=True, augment=False):
        """Create the underlying torchvision dataset.

        NOTE(review): ``train`` and ``augment`` are accepted but not used;
        torchvision's ``SVHN`` is created with its default split. Confirm
        whether ``train=False`` should select the test split.
        """
        transform = transforms.ToTensor()
        self.data = SVHN("/root/svhn", transform=transform, download=True)
        self.one_hot_map = np.eye(10)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(im_corrupt, im, label)`` for ``index``.

        Raises:
            ValueError: if ``FLAGS.datasource`` is neither ``'default'`` nor
                ``'random'``.
        """
        # With FLAGS.single set, every index maps to the first sample.
        if not FLAGS.single:
            im, label = self.data[index]
        else:
            # Fixed: this used to bind `em`, leaving `im` undefined below.
            im, label = self.data[0]

        im = np.transpose(im, (1, 2, 0)).numpy()
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)

        # Reseed numpy's global RNG from the index and the wall clock
        # (presumably to decorrelate DataLoader workers -- TODO confirm).
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unknown datasource {!r}; expected 'default' or 'random'".format(
                    FLAGS.datasource))

        return im_corrupt, im, label
class Mnist(Dataset):
    """MNIST samples as ``(im_corrupt, im, one_hot_label)`` numpy triples.

    Images are 28x28 with no channel axis.
    """

    def __init__(self, train=True, rescale=1.0):
        """Create the underlying torchvision dataset.

        Args:
            train: passed to torchvision's ``MNIST``.
            rescale: factor applied to the image and to the range of the
                'random' corruption.
        """
        self.data = MNIST(
            "/root/mnist",
            transform=transforms.ToTensor(),
            download=True, train=train)
        self.labels = np.eye(10)
        self.rescale = rescale

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(im_corrupt, im, label)`` for ``index``.

        Raises:
            ValueError: if ``FLAGS.datasource`` is neither ``'default'`` nor
                ``'random'``.
        """
        image_size = 28

        im, label = self.data[index]
        label = self.labels[label]
        im = im.squeeze()
        # Scale by 255/256 and add uniform noise in [0, 1/256).
        im = im.numpy() / 256 * 255 + \
            np.random.uniform(0, 1. / 256, (image_size, image_size))
        im = im * self.rescale

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0, self.rescale, (image_size, image_size))
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unknown datasource {!r}; expected 'default' or 'random'".format(
                    FLAGS.datasource))

        return im_corrupt, im, label
class DSprites(Dataset):
    """dSprites images paired with a corrupted copy and a label.

    Which samples are kept and which latent columns become the label is
    chosen by the ``FLAGS.d*_only`` / ``FLAGS.dsprites_restrict`` flags or,
    when none of them is set, by the ``cond_*`` constructor arguments.
    """

    def __init__(
            self,
            cond_size=False,
            cond_shape=False,
            cond_pos=False,
            cond_rot=False):
        """Load the dSprites archive and select data and labels.

        Args:
            cond_size: label is latent column 2.
            cond_shape: label is latent column 1.
            cond_pos: label is latent columns 4 onward.
            cond_rot: label is (cos, sin) of latent column 3.

        The ``cond_*`` arguments are only consulted when none of the
        dataset-restricting flags is set; with none of them given the label
        is latent column 1.
        """
        dat = np.load(FLAGS.dsprites_path)
        latents = dat['latents_values']

        if FLAGS.dshape_only:
            # Fix columns 2-5; only column 1 varies and becomes the label.
            mask = (latents[:, 4] == 16 / 31) & (latents[:, 5] == 16 / 31) & (
                latents[:, 2] == 0.5) & (latents[:, 3] == 30 * np.pi / 39)
            # The few matching samples are tiled to form a longer dataset.
            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
            self.label = np.tile(latents[mask], (10000, 1))
            self.label = self.label[:, 1:2]
        elif FLAGS.dpos_only:
            # Fix columns 1-3; columns 4 onward vary and become the label.
            mask = (latents[:, 1] == 1) & (
                latents[:, 3] == 30 * np.pi / 39) & (latents[:, 2] == 0.5)
            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
            self.label = np.tile(latents[mask], (100, 1))
            self.label = self.label[:, 4:] + 0.5
        elif FLAGS.dsize_only:
            # Fix columns 1, 3, 4, 5; column 2 varies and becomes the label.
            mask = (latents[:, 3] == 30 * np.pi / 39) & (
                latents[:, 4] == 16 / 31) & (
                latents[:, 5] == 16 / 31) & (latents[:, 1] == 1)
            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
            self.label = np.tile(latents[mask], (10000, 1))
            self.label = self.label[:, 2:3]
        elif FLAGS.drot_only:
            # Fix columns 1, 2, 4, 5; column 3 varies and is encoded as
            # (cos, sin).
            mask = (latents[:, 2] == 0.5) & (latents[:, 4] == 16 / 31) & (
                latents[:, 5] == 16 / 31) & (latents[:, 1] == 1)
            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
            self.label = np.tile(latents[mask], (100, 1))
            self.label = self.label[:, 3:4]
            self.label = np.concatenate(
                [np.cos(self.label), np.sin(self.label)], axis=1)
        elif FLAGS.dsprites_restrict:
            mask = (latents[:, 1] == 1) & (latents[:, 3] == 0 * np.pi / 39)
            self.data = dat['imgs'][mask]
            # NOTE(review): every latent column is kept as the label here --
            # confirm this is what __getitem__ callers expect.
            self.label = latents[mask]
        else:
            self.data = dat['imgs']
            self.label = latents

            if cond_size:
                self.label = self.label[:, 2:3]
            elif cond_shape:
                self.label = self.label[:, 1:2]
            elif cond_pos:
                self.label = self.label[:, 4:]
            elif cond_rot:
                self.label = self.label[:, 3:4]
                self.label = np.concatenate(
                    [np.cos(self.label), np.sin(self.label)], axis=1)
            else:
                self.label = self.label[:, 1:2]

        # One-hot table with three rows, indexed by label value minus one.
        self.identity = np.eye(3)

    def __len__(self):
        """Return the number of (possibly tiled) images."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(im_corrupt, im, label)`` for ``index``.

        Raises:
            ValueError: if ``FLAGS.datasource`` is neither ``'default'`` nor
                ``'random'``.
        """
        im = self.data[index]
        image_size = 64

        # NOTE(review): this reads FLAGS.cond_* (defined outside this block),
        # not the cond_* constructor arguments -- confirm they always agree.
        raw_label = (
            FLAGS.dpos_only or FLAGS.dsize_only or FLAGS.cond_size
            or FLAGS.cond_pos or FLAGS.cond_rot or FLAGS.drot_only)

        if not raw_label:
            # One-hot encode the label through the 3-row identity table.
            label = self.identity[self.label[index].astype(
                np.int32) - 1].squeeze()
        else:
            label = self.label[index]

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unknown datasource {!r}; expected 'default' or 'random'".format(
                    FLAGS.datasource))

        return im_corrupt, im, label
class Imagenet(Dataset):
    """32x32 ImageNet loaded from pickled batches under ``FLAGS.imagenet_path``.

    Each item is ``(im_corrupt, im, one_hot_label)`` with channel-last
    32x32x3 images.
    """

    def __init__(self, train=True, augment=False):
        """Read the pickled batches into memory.

        Args:
            train: if True, load ``train_data_batch_1`` .. ``_10``;
                otherwise load ``val_data``.
            augment: accepted but not used.

        NOTE: ``pickle.load`` can execute arbitrary code; only point
        ``FLAGS.imagenet_path`` at trusted files.
        """
        if train:
            labels = []
            chunks = []
            for i in range(1, 11):
                batch_path = osp.join(
                    FLAGS.imagenet_path, 'train_data_batch_{}'.format(i))
                # `with` closes each file; the old code left handles open.
                with open(batch_path, 'rb') as fh:
                    f = pickle.load(fh)
                labels.extend(f['labels'])
                chunks.append(f['data'])
            # Stack once instead of re-copying the growing array per batch.
            data = np.vstack(chunks)
        else:
            with open(osp.join(FLAGS.imagenet_path, 'val_data'), 'rb') as fh:
                f = pickle.load(fh)
            labels = f['labels']
            data = f['data']

        self.labels = labels
        self.data = data
        self.one_hot_map = np.eye(1000)

    def __len__(self):
        """Return the number of image rows loaded."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(im_corrupt, im, label)`` for ``index``.

        Raises:
            ValueError: if ``FLAGS.datasource`` is neither ``'default'`` nor
                ``'random'``.
        """
        # With FLAGS.single set, every index maps to the first sample.
        if not FLAGS.single:
            im, label = self.data[index], self.labels[index]
        else:
            im, label = self.data[0], self.labels[0]

        # Labels are shifted by one before indexing the one-hot table.
        label -= 1

        # Each row is reshaped to (3, 32, 32), then made channel-last.
        im = im.reshape((3, 32, 32)) / 255
        im = im.transpose((1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)

        # Reseed numpy's global RNG from the index and the wall clock
        # (presumably to decorrelate DataLoader workers -- TODO confirm).
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                "Unknown datasource {!r}; expected 'default' or 'random'".format(
                    FLAGS.datasource))

        return im_corrupt, im, label
class Textures(Dataset):
    """Texture images served as 32x32 top-left crops with small noise.

    The dataset reports twice the length of the underlying image folder;
    indices wrap around modulo the folder size.
    """

    def __init__(self, train=True, augment=False):
        """Open the image folder; ``train`` and ``augment`` are not used."""
        self.dataset = ImageFolder("/mnt/nfs/yilundu/data/dtd/images")

    def __len__(self):
        """Return twice the number of images in the folder."""
        return 2 * len(self.dataset)

    def __getitem__(self, index):
        """Return ``(im, im, label)``; both image slots hold the same array."""
        wrapped = index % len(self.dataset)
        sample, label = self.dataset[wrapped]

        # Top-left 32x32 crop, scaled by 1/255.
        crop = np.array(sample)[:32, :32] / 255
        # Uniform noise in [-1/512, 1/512).
        noisy = crop + np.random.uniform(-1 / 512, 1 / 512, crop.shape)
        return noisy, noisy, label