mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git (synced 2025-08-17 10:40:13 -04:00)

chores: refactor for the new ai research, add linter, gh action, etc (#27)

parent fb4ab80dc3 · commit d5467e559f
40 changed files with 5177 additions and 2476 deletions
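Of the 40 files, the excerpt below is EBMs/data.py. Most of its hunks are the mechanical changes a linter/formatter pass produces: imports regrouped and sorted, single quotes switched to double quotes, long calls wrapped with trailing commas, and backslash continuations replaced by parentheses. A minimal sketch of the before/after pattern on a generic expression (illustrative, not repo code):

# before: single quotes and a backslash continuation
total = 'prefix' + \
    'suffix'

# after (black-style): double quotes, no backslash continuation
total = "prefix" + "suffix"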
EBMs/data.py (311 changed lines)
@@ -1,42 +1,48 @@
-from tensorflow.python.platform import flags
-from tensorflow.contrib.data.python.ops import batching, threadpool
-import tensorflow as tf
-import json
-from torch.utils.data import Dataset
-import pickle
-import os.path as osp
-import os
-import numpy as np
-import os.path as osp
-import pickle
-import time
-from scipy.misc import imread, imresize
-from skimage.color import rgb2grey
-from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
-from torchvision import transforms
-from imagenet_preprocessing import ImagenetPreprocessor
+import json
+import os
+import os.path as osp
+import pickle
+import time
+
+import numpy as np
+import tensorflow as tf
+import torch
+import torchvision
+from imagenet_preprocessing import ImagenetPreprocessor
+from scipy.misc import imread, imresize
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.platform import flags
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder
 
 FLAGS = flags.FLAGS
 ROOT_DIR = "./results"
 
 # Dataset Options
-flags.DEFINE_string('dsprites_path',
-                    '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
-                    'path to dsprites characters')
-flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image')
-flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
-flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
-flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
-flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
-flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
-flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
+flags.DEFINE_string(
+    "dsprites_path",
+    "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
+    "path to dsprites characters",
+)
+flags.DEFINE_string(
+    "imagenet_datadir", "/root/imagenet_big", "whether cutoff should always in image"
+)
+flags.DEFINE_bool("dshape_only", False, "fix all factors except for shapes")
+flags.DEFINE_bool("dpos_only", False, "fix all factors except for positions of shapes")
+flags.DEFINE_bool("dsize_only", False, "fix all factors except for size of objects")
+flags.DEFINE_bool("drot_only", False, "fix all factors except for rotation of objects")
+flags.DEFINE_bool(
+    "dsprites_restrict", False, "fix all factors except for rotation of objects"
+)
+flags.DEFINE_string("imagenet_path", "/root/imagenet", "path to imagenet images")
 
 
 # Data augmentation options
-flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image')
-flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
-flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
-flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
+flags.DEFINE_bool("cutout_inside", False, "whether cutoff should always in image")
+flags.DEFINE_float("cutout_prob", 1.0, "probability of using cutout")
+flags.DEFINE_integer("cutout_mask_size", 16, "size of cutout")
+flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")
 
 
 def cutout(mask_color=(0, 0, 0)):
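The DEFINE_* calls above register module-level command-line flags that are read through FLAGS after parsing. In TF 1.x, tensorflow.python.platform.flags is essentially a thin wrapper over absl.flags, so the same mechanism can be sketched standalone with absl (hypothetical script, not part of the commit):

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")

def main(_argv):
    # app.run parses sys.argv into FLAGS before calling main.
    print("cutout enabled:", FLAGS.cutout)

if __name__ == "__main__":
    app.run(main)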
@@ -91,13 +97,15 @@ class TFImagenetLoader(Dataset):
         self.curr_sample = 0
 
-        index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
+        index_path = osp.join(FLAGS.imagenet_datadir, "index.json")
         with open(index_path) as f:
             metadata = json.load(f)
-        counts = metadata['record_counts']
+        counts = metadata["record_counts"]
 
-        if split == 'train':
-            file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
+        if split == "train":
+            file_names = list(
+                sorted([x for x in counts.keys() if x.startswith("train")])
+            )
 
             result_records_to_skip = None
             files = []
@@ -111,30 +119,44 @@ class TFImagenetLoader(Dataset):
                     # Record the number to skip in the first file
                     result_records_to_skip = records_to_skip
                     files.append(filename)
-                    records_to_read -= (records_in_file - records_to_skip)
+                    records_to_read -= records_in_file - records_to_skip
                     records_to_skip = 0
                 else:
                     break
         else:
-            files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
+            files = list(
+                sorted([x for x in counts.keys() if x.startswith("validation")])
+            )
 
         files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
-        preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
+        preprocess_function = ImagenetPreprocessor(
+            128, dtype=tf.float32, train=False
+        ).parse_and_preprocess
 
-        ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
+        ds = tf.data.TFRecordDataset.from_generator(
+            lambda: files, output_types=tf.string
+        )
         ds = ds.apply(tf.data.TFRecordDataset)
         ds = ds.take(im_length)
         ds = ds.prefetch(buffer_size=FLAGS.batch_size)
         ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
-        ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
+        ds = ds.apply(
+            batching.map_and_batch(
+                map_func=preprocess_function,
+                batch_size=FLAGS.batch_size,
+                num_parallel_batches=4,
+            )
+        )
         ds = ds.prefetch(buffer_size=2)
 
         ds_iterator = ds.make_initializable_iterator()
         labels, images = ds_iterator.get_next()
-        self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
+        self.images = tf.clip_by_value(
+            images / 256 + tf.random_uniform(tf.shape(images), 0, 1.0 / 256), 0.0, 1.0
+        )
         self.labels = labels
 
-        config = tf.ConfigProto(device_count = {'GPU': 0})
+        config = tf.ConfigProto(device_count={"GPU": 0})
         sess = tf.Session(config=config)
         sess.run(ds_iterator.initializer)
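The clip_by_value line above performs uniform dequantization: 8-bit pixels are scaled into [0, 1) and jittered by noise no wider than one quantization bin (1/256), so the model never sees exactly discrete intensities. A NumPy sketch of the same transform (standalone illustration, not part of the commit):

import numpy as np

def dequantize_uint8(images):
    # Scale 8-bit pixels to [0, 1) and add sub-bin uniform noise,
    # mirroring images / 256 + U(0, 1/256) in the TF pipeline above.
    x = images.astype(np.float32) / 256.0
    x += np.random.uniform(0.0, 1.0 / 256.0, size=x.shape)
    return np.clip(x, 0.0, 1.0)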
@@ -147,11 +169,17 @@ class TFImagenetLoader(Dataset):
         sess = self.sess
 
-        im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
+        im_corrupt = np.random.uniform(
+            0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)
+        )
         label, im = sess.run([self.labels, self.images])
         im = im * self.rescale
         label = np.eye(1000)[label.squeeze() - 1]
-        im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
+        im, im_corrupt, label = (
+            torch.from_numpy(im),
+            torch.from_numpy(im_corrupt),
+            torch.from_numpy(label),
+        )
         return im_corrupt, im, label
 
     def __iter__(self):
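np.eye(1000)[label.squeeze() - 1] in the hunk above is the identity-matrix idiom for one-hot encoding: indexing np.eye(K) with (here 1-based) class ids returns the matching one-hot rows. A tiny self-contained sketch:

import numpy as np

labels = np.array([1, 3, 2])      # 1-based class ids, as in the loader above
one_hot = np.eye(4)[labels - 1]   # row k of np.eye(4) is the one-hot vector for class k+1
assert one_hot[0].tolist() == [1.0, 0.0, 0.0, 0.0]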
@@ -160,6 +188,7 @@ class TFImagenetLoader(Dataset):
     def __len__(self):
         return self.im_length
 
+
 class CelebA(Dataset):
 
     def __init__(self):
@@ -180,25 +209,18 @@ class CelebA(Dataset):
         im = imread(path)
         im = imresize(im, (32, 32))
         image_size = 32
-        im = im / 255.
+        im = im / 255.0
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0, 1, size=(image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0, 1, size=(image_size, image_size, 3))
 
         return im_corrupt, im, label
 
 
 class Cifar10(Dataset):
-    def __init__(
-        self,
-        train=True,
-        full=False,
-        augment=False,
-        noise=True,
-        rescale=1.0):
+    def __init__(self, train=True, full=False, augment=False, noise=True, rescale=1.0):
 
         if augment:
             transform_list = [
@@ -215,16 +237,10 @@ class Cifar10(Dataset):
         transform = transforms.ToTensor()
 
         self.full = full
-        self.data = CIFAR10(
-            ROOT_DIR,
-            transform=transform,
-            train=train,
-            download=True)
+        self.data = CIFAR10(ROOT_DIR, transform=transform, train=train, download=True)
         self.test_data = CIFAR10(
-            ROOT_DIR,
-            transform=transform,
-            train=False,
-            download=True)
+            ROOT_DIR, transform=transform, train=False, download=True
+        )
         self.one_hot_map = np.eye(10)
         self.noise = noise
         self.rescale = rescale
@@ -255,16 +271,18 @@ class Cifar10(Dataset):
             im = im * 255 / 256
 
         if self.noise:
-            im = im * self.rescale + \
-                np.random.uniform(0, self.rescale * 1 / 256., im.shape)
+            im = im * self.rescale + np.random.uniform(
+                0, self.rescale * 1 / 256.0, im.shape
+            )
 
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
+        elif FLAGS.datasource == "random":
             im_corrupt = np.random.uniform(
-                0.0, self.rescale, (image_size, image_size, 3))
+                0.0, self.rescale, (image_size, image_size, 3)
+            )
 
         return im_corrupt, im, label
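Every dataset class in this file returns the triple (im_corrupt, im, label): im is the clean image and im_corrupt is a noisy starting point, which in an EBM codebase like this one presumably seeds the Langevin/MCMC sampler. The branch on FLAGS.datasource recurs nearly verbatim across classes; a compact sketch of it as a standalone helper (hypothetical refactor, not in the commit):

import numpy as np

def make_corrupt_seed(im, datasource, rescale=1.0):
    # "default": perturb the clean image with Gaussian noise;
    # "random": ignore the image and start from uniform noise.
    if datasource == "default":
        return im + 0.3 * np.random.randn(*im.shape)
    if datasource == "random":
        return np.random.uniform(0.0, rescale, im.shape)
    raise ValueError("unknown datasource: %s" % datasource)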
@@ -287,10 +305,8 @@ class Cifar100(Dataset):
         transform = transforms.ToTensor()
 
         self.data = CIFAR100(
-            "/root/cifar100",
-            transform=transform,
-            train=train,
-            download=True)
+            "/root/cifar100", transform=transform, train=train, download=True
+        )
         self.one_hot_map = np.eye(100)
 
     def __len__(self):
@@ -308,11 +324,10 @@ class Cifar100(Dataset):
         im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0.0, 1.0, (image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
 
         return im_corrupt, im, label
@@ -340,11 +355,10 @@ class Svhn(Dataset):
         im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0.0, 1.0, (image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
 
         return im_corrupt, im, label
@@ -352,9 +366,8 @@ class Svhn(Dataset):
 class Mnist(Dataset):
     def __init__(self, train=True, rescale=1.0):
         self.data = MNIST(
-            "/root/mnist",
-            transform=transforms.ToTensor(),
-            download=True, train=train)
+            "/root/mnist", transform=transforms.ToTensor(), download=True, train=train
+        )
         self.labels = np.eye(10)
         self.rescale = rescale
@@ -367,13 +380,13 @@ class Mnist(Dataset):
         im = im.squeeze()
         # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
         # im = im.numpy() / 2 + 0.2
-        im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
+        im = im.numpy() / 256 * 255 + np.random.uniform(0, 1.0 / 256, (28, 28))
         im = im * self.rescale
         image_size = 28
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
-        elif FLAGS.datasource == 'random':
+        elif FLAGS.datasource == "random":
             im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
 
         return im_corrupt, im, label
|
|||
|
||||
class DSprites(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
cond_size=False,
|
||||
cond_shape=False,
|
||||
cond_pos=False,
|
||||
cond_rot=False):
|
||||
self, cond_size=False, cond_shape=False, cond_pos=False, cond_rot=False
|
||||
):
|
||||
dat = np.load(FLAGS.dsprites_path)
|
||||
|
||||
if FLAGS.dshape_only:
|
||||
l = dat['latents_values']
|
||||
mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
|
||||
31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
|
||||
l = dat["latents_values"]
|
||||
mask = (
|
||||
(l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 2] == 0.5)
|
||||
& (l[:, 3] == 30 * np.pi / 39)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
|
||||
self.label = self.label[:, 1:2]
|
||||
elif FLAGS.dpos_only:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
mask = (l[:, 1] == 1) & (
|
||||
l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (100, 1))
|
||||
mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
|
||||
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (100, 1))
|
||||
self.label = self.label[:, 4:] + 0.5
|
||||
elif FLAGS.dsize_only:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
# mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
|
||||
mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
|
||||
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
|
||||
self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (10000, 1))
|
||||
self.label = (self.label[:, 2:3])
|
||||
mask = (
|
||||
(l[:, 3] == 30 * np.pi / 39)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (10000, 1))
|
||||
self.label = self.label[:, 2:3]
|
||||
elif FLAGS.drot_only:
|
||||
l = dat['latents_values']
|
||||
mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
|
||||
31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
|
||||
self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat['latents_values'][mask], (100, 1))
|
||||
self.label = (self.label[:, 3:4])
|
||||
l = dat["latents_values"]
|
||||
mask = (
|
||||
(l[:, 2] == 0.5)
|
||||
& (l[:, 4] == 16 / 31)
|
||||
& (l[:, 5] == 16 / 31)
|
||||
& (l[:, 1] == 1)
|
||||
)
|
||||
self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
|
||||
self.label = np.tile(dat["latents_values"][mask], (100, 1))
|
||||
self.label = self.label[:, 3:4]
|
||||
self.label = np.concatenate(
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1)
|
||||
[np.cos(self.label), np.sin(self.label)], axis=1
|
||||
)
|
||||
elif FLAGS.dsprites_restrict:
|
||||
l = dat['latents_values']
|
||||
l = dat["latents_values"]
|
||||
mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
|
||||
|
||||
self.data = dat['imgs'][mask]
|
||||
self.label = dat['latents_values'][mask]
|
||||
self.data = dat["imgs"][mask]
|
||||
self.label = dat["latents_values"][mask]
|
||||
else:
|
||||
self.data = dat['imgs']
|
||||
self.label = dat['latents_values']
|
||||
self.data = dat["imgs"]
|
||||
self.label = dat["latents_values"]
|
||||
|
||||
if cond_size:
|
||||
self.label = self.label[:, 2:3]
|
||||
|
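The dSprites archive stores a latents_values array whose columns are, per the dataset's release notes, (color, shape, scale, orientation, posX, posY). Each d*_only flag above pins every factor except the one under study with a boolean mask, then tiles the small surviving subset to stretch the epoch. A standalone sketch of the idiom behind the dshape_only branch (file path and tiling factor illustrative):

import numpy as np

dat = np.load("dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
l = dat["latents_values"]
# Pin scale, orientation, and position so that shape is the only
# factor still varying across the selected samples.
mask = (
    (l[:, 2] == 0.5)                # scale
    & (l[:, 3] == 30 * np.pi / 39)  # orientation
    & (l[:, 4] == 16 / 31)          # posX
    & (l[:, 5] == 16 / 31)          # posY
)
data = np.tile(dat["imgs"][mask], (10000, 1, 1))  # tile the tiny subset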
@@ -439,7 +461,8 @@ class DSprites(Dataset):
         elif cond_rot:
             self.label = self.label[:, 3:4]
             self.label = np.concatenate(
-                [np.cos(self.label), np.sin(self.label)], axis=1)
+                [np.cos(self.label), np.sin(self.label)], axis=1
+            )
         else:
             self.label = self.label[:, 1:2]
@@ -452,20 +475,20 @@ class DSprites(Dataset):
         im = self.data[index]
         image_size = 64
 
-        if not (
-                FLAGS.dpos_only or FLAGS.dsize_only) and (
-                not FLAGS.cond_size) and (
-                not FLAGS.cond_pos) and (
-                not FLAGS.cond_rot) and (
-                not FLAGS.drot_only):
-            label = self.identity[self.label[index].astype(
-                np.int32) - 1].squeeze()
+        if (
+            not (FLAGS.dpos_only or FLAGS.dsize_only)
+            and (not FLAGS.cond_size)
+            and (not FLAGS.cond_pos)
+            and (not FLAGS.cond_rot)
+            and (not FLAGS.drot_only)
+        ):
+            label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
         else:
             label = self.label[index]
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
-        elif FLAGS.datasource == 'random':
+        elif FLAGS.datasource == "random":
             im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
 
         return im_corrupt, im, label
@@ -478,25 +501,20 @@ class Imagenet(Dataset):
             for i in range(1, 11):
                 f = pickle.load(
                     open(
-                        osp.join(
-                            FLAGS.imagenet_path,
-                            'train_data_batch_{}'.format(i)),
-                        'rb'))
+                        osp.join(FLAGS.imagenet_path, "train_data_batch_{}".format(i)),
+                        "rb",
+                    )
+                )
                 if i == 1:
-                    labels = f['labels']
-                    data = f['data']
+                    labels = f["labels"]
+                    data = f["data"]
                 else:
-                    labels.extend(f['labels'])
-                    data = np.vstack((data, f['data']))
+                    labels.extend(f["labels"])
+                    data = np.vstack((data, f["data"]))
         else:
-            f = pickle.load(
-                open(
-                    osp.join(
-                        FLAGS.imagenet_path,
-                        'val_data'),
-                    'rb'))
-            labels = f['labels']
-            data = f['data']
+            f = pickle.load(open(osp.join(FLAGS.imagenet_path, "val_data"), "rb"))
+            labels = f["labels"]
+            data = f["data"]
 
         self.labels = labels
         self.data = data
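The Imagenet class reads the downsampled-ImageNet pickle release: ten train_data_batch_{i} files plus one val_data file, each unpickling to a dict with "data" (flattened image rows) and "labels" (integer class ids). A sketch of loading a single batch under that assumed layout:

import os.path as osp
import pickle

def load_batch(root, name):
    # Each file unpickles to a dict holding image rows and class ids;
    # "root" and "name" are illustrative arguments.
    with open(osp.join(root, name), "rb") as fh:
        batch = pickle.load(fh)
    return batch["data"], batch["labels"]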
@@ -520,11 +538,10 @@ class Imagenet(Dataset):
         im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0.0, 1.0, (image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
 
         return im_corrupt, im, label