mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git
synced 2025-05-03 07:05:24 -04:00
add cool gpt3 examples
This commit is contained in:
parent 0e8871e15e
commit 88d4a6f793
16 changed files with 205 additions and 0 deletions
EBMs/data.py (new file, +546 lines)
@@ -0,0 +1,546 @@
from tensorflow.python.platform import flags
from tensorflow.contrib.data.python.ops import batching, threadpool
import tensorflow as tf
import json
from torch.utils.data import Dataset
import pickle
import os.path as osp
import os
import numpy as np
import time
from scipy.misc import imread, imresize
from skimage.color import rgb2grey
from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
from torchvision import transforms
from imagenet_preprocessing import ImagenetPreprocessor
import torch
import torchvision

FLAGS = flags.FLAGS
ROOT_DIR = "./results"

# Dataset Options
flags.DEFINE_string('dsprites_path',
                    '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
                    'path to dsprites characters')
flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'path to imagenet tfrecord data')
flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
flags.DEFINE_bool('dsize_only', False, 'fix all factors except for size of objects')
flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
flags.DEFINE_bool('dsprites_restrict', False, 'restrict dsprites to a fixed shape and rotation')
flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')


# Data augmentation options
flags.DEFINE_bool('cutout_inside', False, 'whether the cutout mask must lie fully inside the image')
flags.DEFINE_float('cutout_prob', 1.0, 'probability of applying cutout')
flags.DEFINE_integer('cutout_mask_size', 16, 'size of the cutout mask')
flags.DEFINE_bool('cutout', False, 'whether to add cutout regularization to the data')

def cutout(mask_color=(0, 0, 0)):
    mask_size_half = FLAGS.cutout_mask_size // 2
    offset = 1 if FLAGS.cutout_mask_size % 2 == 0 else 0

    def _cutout(image):
        # Applied after ToTensor(), so `image` arrives in CHW layout.
        image = np.asarray(image).copy()

        if np.random.random() > FLAGS.cutout_prob:
            return image

        h, w = image.shape[-2:]  # last two dims of CHW; shape[:2] would read (C, H)

        if FLAGS.cutout_inside:
            cxmin, cxmax = mask_size_half, w + offset - mask_size_half
            cymin, cymax = mask_size_half, h + offset - mask_size_half
        else:
            cxmin, cxmax = 0, w + offset
            cymin, cymax = 0, h + offset

        # Sample a mask center, then clip the square mask to the image bounds.
        cx = np.random.randint(cxmin, cxmax)
        cy = np.random.randint(cymin, cymax)
        xmin = cx - mask_size_half
        ymin = cy - mask_size_half
        xmax = xmin + FLAGS.cutout_mask_size
        ymax = ymin + FLAGS.cutout_mask_size
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None]
        return image

    return _cutout
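
# Usage sketch (illustrative, not part of the original file): cutout() slots in
# after ToTensor(), exactly as Cifar10 below composes it when FLAGS.cutout is
# set. The flag value shown here is an assumption.
#
#   FLAGS.cutout_mask_size = 16
#   transform = transforms.Compose([
#       transforms.RandomCrop(32, padding=4),
#       transforms.RandomHorizontalFlip(),
#       transforms.ToTensor(),
#       cutout(),   # zeroes out one random 16x16 patch per image
#   ])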

class TFImagenetLoader(Dataset):
    """Streams ImageNet tfrecords through a tf.data pipeline, yielding torch tensors."""

    def __init__(self, split, batchsize, idx, num_workers, rescale=1):
        IMAGENET_NUM_TRAIN_IMAGES = 1281167
        IMAGENET_NUM_VAL_IMAGES = 50000

        self.rescale = rescale

        if split == "train":
            im_length = IMAGENET_NUM_TRAIN_IMAGES
            records_to_skip = im_length * idx // num_workers
            records_to_read = im_length * (idx + 1) // num_workers - records_to_skip
        else:
            im_length = IMAGENET_NUM_VAL_IMAGES

        self.curr_sample = 0

        index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
        with open(index_path) as f:
            metadata = json.load(f)
            counts = metadata['record_counts']

        if split == 'train':
            file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))

            result_records_to_skip = None
            files = []
            # Select this worker's contiguous shard of record files.
            for filename in file_names:
                records_in_file = counts[filename]
                if records_to_skip >= records_in_file:
                    records_to_skip -= records_in_file
                    continue
                elif records_to_read > 0:
                    if result_records_to_skip is None:
                        # Record the number to skip in the first file
                        result_records_to_skip = records_to_skip
                    files.append(filename)
                    records_to_read -= (records_in_file - records_to_skip)
                    records_to_skip = 0
                else:
                    break
        else:
            files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))

        files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
        preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess

        ds = tf.data.Dataset.from_generator(lambda: files, output_types=tf.string)
        ds = ds.apply(tf.data.TFRecordDataset)
        ds = ds.take(im_length)
        ds = ds.prefetch(buffer_size=FLAGS.batch_size)
        ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
        ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
        ds = ds.prefetch(buffer_size=2)

        ds_iterator = ds.make_initializable_iterator()
        labels, images = ds_iterator.get_next()
        # Map pixels to [0, 1] and dequantize with uniform noise of width 1/256.
        self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
        self.labels = labels

        config = tf.ConfigProto(device_count={'GPU': 0})
        sess = tf.Session(config=config)
        sess.run(ds_iterator.initializer)

        self.im_length = im_length // batchsize

        self.sess = sess

    def __next__(self):
        self.curr_sample += 1

        sess = self.sess

        im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
        label, im = sess.run([self.labels, self.images])
        im = im * self.rescale
        label = np.eye(1000)[label.squeeze() - 1]
        im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
        return im_corrupt, im, label

    def __iter__(self):
        return self

    def __len__(self):
        return self.im_length
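
# Usage sketch (illustrative, not part of the original file): the loader is an
# infinite iterator over whole batches; any split other than "train" reads the
# validation records. FLAGS.batch_size is assumed to be defined by the caller.
#
#   loader = TFImagenetLoader('val', batchsize=FLAGS.batch_size, idx=0, num_workers=1)
#   im_corrupt, im, label = next(loader)
#   # im, im_corrupt: (batch, 128, 128, 3); label: (batch, 1000) one-hot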

class CelebA(Dataset):

    def __init__(self):
        self.path = "/root/data/img_align_celeba"
        self.ims = os.listdir(self.path)
        self.ims = [osp.join(self.path, im) for im in self.ims]

    def __len__(self):
        return len(self.ims)

    def __getitem__(self, index):
        label = 1

        if FLAGS.single:
            index = 0

        path = self.ims[index]
        im = imread(path)
        im = imresize(im, (32, 32))
        image_size = 32
        im = im / 255.

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0, 1, size=(image_size, image_size, 3))

        return im_corrupt, im, label

class Cifar10(Dataset):
    def __init__(
            self,
            train=True,
            full=False,
            augment=False,
            noise=True,
            rescale=1.0):

        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.full = full
        self.data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=train,
            download=True)
        self.test_data = CIFAR10(
            ROOT_DIR,
            transform=transform,
            train=False,
            download=True)
        self.one_hot_map = np.eye(10)
        self.noise = noise
        self.rescale = rescale

    def __len__(self):
        if self.full:
            return len(self.data) + len(self.test_data)
        else:
            return len(self.data)

    def __getitem__(self, index):
        if not FLAGS.single:
            if self.full:
                if index >= len(self.data):
                    im, label = self.test_data[index - len(self.data)]
                else:
                    im, label = self.data[index]
            else:
                im, label = self.data[index]
        else:
            im, label = self.data[0]

        im = im.numpy().transpose((1, 2, 0))  # CHW tensor -> HWC array
        image_size = 32
        label = self.one_hot_map[label]

        im = im * 255 / 256

        if self.noise:
            im = im * self.rescale + \
                np.random.uniform(0, self.rescale * 1 / 256., im.shape)

        # Reseed so repeated indices across workers draw fresh noise.
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, self.rescale, (image_size, image_size, 3))

        return im_corrupt, im, label
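
# Usage sketch (illustrative, not part of the original file): every dataset in
# this module yields (im_corrupt, im, label) triples, where im is the clean,
# dequantized image and im_corrupt appears to be the noise initialization
# consumed by the EBM training loop (an assumption; the consumer lives outside
# this file). FLAGS.single and FLAGS.datasource are assumed to be defined by
# the importing training script.
#
#   from torch.utils.data import DataLoader
#   dataset = Cifar10(train=True, augment=True, noise=True)
#   loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
#   im_corrupt, im, label = next(iter(loader))
#   # im_corrupt, im: (128, 32, 32, 3); label: (128, 10) one-hot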

class Cifar100(Dataset):
    def __init__(self, train=True, augment=False):

        if augment:
            transform_list = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]

            if FLAGS.cutout:
                transform_list.append(cutout())

            transform = transforms.Compose(transform_list)
        else:
            transform = transforms.ToTensor()

        self.data = CIFAR100(
            "/root/cifar100",
            transform=transform,
            train=train,
            download=True)
        self.one_hot_map = np.eye(100)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if not FLAGS.single:
            im, label = self.data[index]
        else:
            im, label = self.data[0]

        im = im.numpy().transpose((1, 2, 0))  # CHW tensor -> HWC array
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))

        return im_corrupt, im, label

class Svhn(Dataset):
    def __init__(self, train=True, augment=False):
        transform = transforms.ToTensor()

        # Note: the `train` argument is unused; SVHN loads its default split.
        self.data = SVHN("/root/svhn", transform=transform, download=True)
        self.one_hot_map = np.eye(10)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if not FLAGS.single:
            im, label = self.data[index]
        else:
            im, label = self.data[0]

        im = im.numpy().transpose((1, 2, 0))  # CHW tensor -> HWC array
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))

        return im_corrupt, im, label

class Mnist(Dataset):
    def __init__(self, train=True, rescale=1.0):
        self.data = MNIST(
            "/root/mnist",
            transform=transforms.ToTensor(),
            download=True, train=train)
        self.labels = np.eye(10)
        self.rescale = rescale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        im, label = self.data[index]
        label = self.labels[label]
        im = im.squeeze()
        # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
        # im = im.numpy() / 2 + 0.2
        im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
        im = im * self.rescale
        image_size = 28

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(0, self.rescale, (28, 28))

        return im_corrupt, im, label

class DSprites(Dataset):
    def __init__(
            self,
            cond_size=False,
            cond_shape=False,
            cond_pos=False,
            cond_rot=False):
        dat = np.load(FLAGS.dsprites_path)

        if FLAGS.dshape_only:
            l = dat['latents_values']
            mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (10000, 1))
            self.label = self.label[:, 1:2]
        elif FLAGS.dpos_only:
            l = dat['latents_values']
            # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
            mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (100, 1))
            self.label = self.label[:, 4:] + 0.5
        elif FLAGS.dsize_only:
            l = dat['latents_values']
            # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
            mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (10000, 1))
            self.label = self.label[:, 2:3]
        elif FLAGS.drot_only:
            l = dat['latents_values']
            mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
            self.label = np.tile(dat['latents_values'][mask], (100, 1))
            self.label = self.label[:, 3:4]
            # Encode rotation as (cos, sin) so the label is continuous and periodic.
            self.label = np.concatenate(
                [np.cos(self.label), np.sin(self.label)], axis=1)
        elif FLAGS.dsprites_restrict:
            l = dat['latents_values']
            mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)

            self.data = dat['imgs'][mask]
            self.label = dat['latents_values'][mask]
        else:
            self.data = dat['imgs']
            self.label = dat['latents_values']

        if cond_size:
            self.label = self.label[:, 2:3]
        elif cond_shape:
            self.label = self.label[:, 1:2]
        elif cond_pos:
            self.label = self.label[:, 4:]
        elif cond_rot:
            self.label = self.label[:, 3:4]
            self.label = np.concatenate(
                [np.cos(self.label), np.sin(self.label)], axis=1)
        else:
            self.label = self.label[:, 1:2]

        self.identity = np.eye(3)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        im = self.data[index]
        image_size = 64

        # FLAGS.cond_size / cond_pos / cond_rot are expected to be defined by the
        # importing training script, consistent with the constructor arguments.
        if (not (FLAGS.dpos_only or FLAGS.dsize_only) and not FLAGS.cond_size
                and not FLAGS.cond_pos and not FLAGS.cond_rot and not FLAGS.drot_only):
            # Shape labels are categorical (1-3): convert to one-hot.
            label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
        else:
            label = self.label[index]

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
        elif FLAGS.datasource == 'random':
            im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)

        return im_corrupt, im, label
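
# Usage sketch (illustrative, not part of the original file): conditioning on a
# single dsprites latent. The .npz at FLAGS.dsprites_path stores images under
# 'imgs' and an (N, 6) array under 'latents_values' with columns
# (color, shape, scale, orientation, posX, posY). Note that __getitem__
# consults the FLAGS.cond_* flags, which must be set consistently with the
# constructor arguments.
#
#   dataset = DSprites(cond_rot=True)    # requires FLAGS.cond_rot = True as well
#   im_corrupt, im, label = dataset[0]   # im: (64, 64); label: (cos t, sin t)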

class Imagenet(Dataset):
    def __init__(self, train=True, augment=False):
        if train:
            for i in range(1, 11):
                with open(osp.join(FLAGS.imagenet_path,
                                   'train_data_batch_{}'.format(i)), 'rb') as fh:
                    f = pickle.load(fh)
                if i == 1:
                    labels = f['labels']
                    data = f['data']
                else:
                    labels.extend(f['labels'])
                    data = np.vstack((data, f['data']))
        else:
            with open(osp.join(FLAGS.imagenet_path, 'val_data'), 'rb') as fh:
                f = pickle.load(fh)
            labels = f['labels']
            data = f['data']

        self.labels = labels
        self.data = data
        self.one_hot_map = np.eye(1000)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        if not FLAGS.single:
            im, label = self.data[index], self.labels[index]
        else:
            im, label = self.data[0], self.labels[0]

        # Stored labels are 1-indexed; shift to 0-indexed before one-hot encoding.
        label -= 1

        im = im.reshape((3, 32, 32)) / 255
        im = im.transpose((1, 2, 0))
        image_size = 32
        label = self.one_hot_map[label]
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
        np.random.seed((index + int(time.time() * 1e7)) % 2**32)

        if FLAGS.datasource == 'default':
            im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
        elif FLAGS.datasource == 'random':
            im_corrupt = np.random.uniform(
                0.0, 1.0, (image_size, image_size, 3))

        return im_corrupt, im, label
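
# Usage sketch (illustrative, not part of the original file): this class reads
# the 32x32 downsampled-ImageNet pickle batches ('train_data_batch_1' through
# 'train_data_batch_10', plus 'val_data'), each a dict holding flattened image
# rows under 'data' (presumably uint8) and 1-indexed class ids under 'labels'.
#
#   dataset = Imagenet(train=False)
#   im_corrupt, im, label = dataset[0]
#   # im: (32, 32, 3) in [0, 1]; label: (1000,) one-hot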

class Textures(Dataset):
    def __init__(self, train=True, augment=False):
        self.dataset = ImageFolder("/mnt/nfs/yilundu/data/dtd/images")

    def __len__(self):
        return 2 * len(self.dataset)

    def __getitem__(self, index):
        idx = index % len(self.dataset)
        im, label = self.dataset[idx]

        # Crop the top-left 32x32 patch and dequantize.
        im = np.array(im)[:32, :32] / 255
        im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)

        return im, im, label