chores: refactor for the new ai research, add linter, gh action, etc (#27)

Marina von Steinkirch, PhD authored 2025-08-13 21:49:46 +08:00; committed by von-steinkirch
parent fb4ab80dc3
commit d5467e559f
40 changed files with 5177 additions and 2476 deletions


@@ -1,42 +1,48 @@
-from tensorflow.python.platform import flags
-from tensorflow.contrib.data.python.ops import batching, threadpool
-import tensorflow as tf
 import json
-from torch.utils.data import Dataset
-import pickle
-import os.path as osp
 import os
-import numpy as np
+import os.path as osp
+import pickle
 import time
-from scipy.misc import imread, imresize
-from skimage.color import rgb2grey
-from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
-from torchvision import transforms
-from imagenet_preprocessing import ImagenetPreprocessor
+
+import numpy as np
+import tensorflow as tf
 import torch
 import torchvision
+from imagenet_preprocessing import ImagenetPreprocessor
+from scipy.misc import imread, imresize
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.platform import flags
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, ImageFolder
 
 FLAGS = flags.FLAGS
 ROOT_DIR = "./results"
 
 # Dataset Options
-flags.DEFINE_string('dsprites_path',
-                    '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
-                    'path to dsprites characters')
-flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'whether cutoff should always in image')
-flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
-flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
-flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
-flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
-flags.DEFINE_bool('dsprites_restrict', False, 'fix all factors except for rotation of objects')
-flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
+flags.DEFINE_string(
+    "dsprites_path",
+    "/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
+    "path to dsprites characters",
+)
+flags.DEFINE_string(
+    "imagenet_datadir", "/root/imagenet_big", "whether cutoff should always in image"
+)
+flags.DEFINE_bool("dshape_only", False, "fix all factors except for shapes")
+flags.DEFINE_bool("dpos_only", False, "fix all factors except for positions of shapes")
+flags.DEFINE_bool("dsize_only", False, "fix all factors except for size of objects")
+flags.DEFINE_bool("drot_only", False, "fix all factors except for rotation of objects")
+flags.DEFINE_bool(
+    "dsprites_restrict", False, "fix all factors except for rotation of objects"
+)
+flags.DEFINE_string("imagenet_path", "/root/imagenet", "path to imagenet images")
 
 # Data augmentation options
-flags.DEFINE_bool('cutout_inside', False,'whether cutoff should always in image')
-flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
-flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
-flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
+flags.DEFINE_bool("cutout_inside", False, "whether cutoff should always in image")
+flags.DEFINE_float("cutout_prob", 1.0, "probability of using cutout")
+flags.DEFINE_integer("cutout_mask_size", 16, "size of cutout")
+flags.DEFINE_bool("cutout", False, "whether to add cutout regularizer to data")
 
 
 def cutout(mask_color=(0, 0, 0)):
@@ -91,13 +97,15 @@ class TFImagenetLoader(Dataset):
         self.curr_sample = 0
 
-        index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
+        index_path = osp.join(FLAGS.imagenet_datadir, "index.json")
         with open(index_path) as f:
             metadata = json.load(f)
-            counts = metadata['record_counts']
+            counts = metadata["record_counts"]
 
-        if split == 'train':
-            file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
+        if split == "train":
+            file_names = list(
+                sorted([x for x in counts.keys() if x.startswith("train")])
+            )
 
             result_records_to_skip = None
             files = []
@@ -111,30 +119,44 @@ class TFImagenetLoader(Dataset):
                         # Record the number to skip in the first file
                         result_records_to_skip = records_to_skip
                     files.append(filename)
-                    records_to_read -= (records_in_file - records_to_skip)
+                    records_to_read -= records_in_file - records_to_skip
                     records_to_skip = 0
                 else:
                     break
         else:
-            files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
+            files = list(
+                sorted([x for x in counts.keys() if x.startswith("validation")])
+            )
 
         files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
 
-        preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
+        preprocess_function = ImagenetPreprocessor(
+            128, dtype=tf.float32, train=False
+        ).parse_and_preprocess
 
-        ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
+        ds = tf.data.TFRecordDataset.from_generator(
+            lambda: files, output_types=tf.string
+        )
         ds = ds.apply(tf.data.TFRecordDataset)
         ds = ds.take(im_length)
         ds = ds.prefetch(buffer_size=FLAGS.batch_size)
         ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
-        ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
+        ds = ds.apply(
+            batching.map_and_batch(
+                map_func=preprocess_function,
+                batch_size=FLAGS.batch_size,
+                num_parallel_batches=4,
+            )
+        )
         ds = ds.prefetch(buffer_size=2)
 
         ds_iterator = ds.make_initializable_iterator()
         labels, images = ds_iterator.get_next()
-        self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
+        self.images = tf.clip_by_value(
+            images / 256 + tf.random_uniform(tf.shape(images), 0, 1.0 / 256), 0.0, 1.0
+        )
         self.labels = labels
 
-        config = tf.ConfigProto(device_count = {'GPU': 0})
+        config = tf.ConfigProto(device_count={"GPU": 0})
         sess = tf.Session(config=config)
         sess.run(ds_iterator.initializer)
@@ -147,11 +169,17 @@ class TFImagenetLoader(Dataset):
         sess = self.sess
 
-        im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
+        im_corrupt = np.random.uniform(
+            0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)
+        )
         label, im = sess.run([self.labels, self.images])
         im = im * self.rescale
         label = np.eye(1000)[label.squeeze() - 1]
-        im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
+        im, im_corrupt, label = (
+            torch.from_numpy(im),
+            torch.from_numpy(im_corrupt),
+            torch.from_numpy(label),
+        )
 
         return im_corrupt, im, label
 
     def __iter__(self):
@@ -160,6 +188,7 @@ class TFImagenetLoader(Dataset):
     def __len__(self):
         return self.im_length
 
+
 class CelebA(Dataset):
 
     def __init__(self):
@@ -180,25 +209,18 @@ class CelebA(Dataset):
         im = imread(path)
         im = imresize(im, (32, 32))
         image_size = 32
-        im = im / 255.
+        im = im / 255.0
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0, 1, size=(image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0, 1, size=(image_size, image_size, 3))
 
         return im_corrupt, im, label
 
 
 class Cifar10(Dataset):
-    def __init__(
-            self,
-            train=True,
-            full=False,
-            augment=False,
-            noise=True,
-            rescale=1.0):
+    def __init__(self, train=True, full=False, augment=False, noise=True, rescale=1.0):
 
         if augment:
             transform_list = [
@@ -215,16 +237,10 @@ class Cifar10(Dataset):
             transform = transforms.ToTensor()
 
         self.full = full
-        self.data = CIFAR10(
-            ROOT_DIR,
-            transform=transform,
-            train=train,
-            download=True)
+        self.data = CIFAR10(ROOT_DIR, transform=transform, train=train, download=True)
         self.test_data = CIFAR10(
-            ROOT_DIR,
-            transform=transform,
-            train=False,
-            download=True)
+            ROOT_DIR, transform=transform, train=False, download=True
+        )
         self.one_hot_map = np.eye(10)
         self.noise = noise
         self.rescale = rescale
@@ -255,16 +271,18 @@ class Cifar10(Dataset):
         im = im * 255 / 256
 
         if self.noise:
-            im = im * self.rescale + \
-                np.random.uniform(0, self.rescale * 1 / 256., im.shape)
+            im = im * self.rescale + np.random.uniform(
+                0, self.rescale * 1 / 256.0, im.shape
+            )
 
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
+        elif FLAGS.datasource == "random":
             im_corrupt = np.random.uniform(
-                0.0, self.rescale, (image_size, image_size, 3))
+                0.0, self.rescale, (image_size, image_size, 3)
+            )
 
         return im_corrupt, im, label
@@ -287,10 +305,8 @@ class Cifar100(Dataset):
             transform = transforms.ToTensor()
 
         self.data = CIFAR100(
-            "/root/cifar100",
-            transform=transform,
-            train=train,
-            download=True)
+            "/root/cifar100", transform=transform, train=train, download=True
+        )
         self.one_hot_map = np.eye(100)
 
     def __len__(self):
@@ -308,11 +324,10 @@ class Cifar100(Dataset):
         im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
 
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0.0, 1.0, (image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
 
         return im_corrupt, im, label
@@ -340,11 +355,10 @@ class Svhn(Dataset):
         im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
 
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0.0, 1.0, (image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
 
         return im_corrupt, im, label
@@ -352,9 +366,8 @@ class Svhn(Dataset):
 
 class Mnist(Dataset):
     def __init__(self, train=True, rescale=1.0):
         self.data = MNIST(
-            "/root/mnist",
-            transform=transforms.ToTensor(),
-            download=True, train=train)
+            "/root/mnist", transform=transforms.ToTensor(), download=True, train=train
+        )
         self.labels = np.eye(10)
         self.rescale = rescale
@@ -367,13 +380,13 @@ class Mnist(Dataset):
         im = im.squeeze()
         # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
         # im = im.numpy() / 2 + 0.2
-        im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
+        im = im.numpy() / 256 * 255 + np.random.uniform(0, 1.0 / 256, (28, 28))
         im = im * self.rescale
         image_size = 28
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
-        elif FLAGS.datasource == 'random':
+        elif FLAGS.datasource == "random":
             im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
 
         return im_corrupt, im, label
@@ -381,54 +394,63 @@ class Mnist(Dataset):
 class DSprites(Dataset):
     def __init__(
-            self,
-            cond_size=False,
-            cond_shape=False,
-            cond_pos=False,
-            cond_rot=False):
+        self, cond_size=False, cond_shape=False, cond_pos=False, cond_rot=False
+    ):
         dat = np.load(FLAGS.dsprites_path)
 
         if FLAGS.dshape_only:
-            l = dat['latents_values']
-            mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
-                                           31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
-            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
-            self.label = np.tile(dat['latents_values'][mask], (10000, 1))
+            l = dat["latents_values"]
+            mask = (
+                (l[:, 4] == 16 / 31)
+                & (l[:, 5] == 16 / 31)
+                & (l[:, 2] == 0.5)
+                & (l[:, 3] == 30 * np.pi / 39)
+            )
+            self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
+            self.label = np.tile(dat["latents_values"][mask], (10000, 1))
             self.label = self.label[:, 1:2]
         elif FLAGS.dpos_only:
-            l = dat['latents_values']
+            l = dat["latents_values"]
             # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
-            mask = (l[:, 1] == 1) & (
-                l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
-            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
-            self.label = np.tile(dat['latents_values'][mask], (100, 1))
+            mask = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
+            self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
+            self.label = np.tile(dat["latents_values"][mask], (100, 1))
             self.label = self.label[:, 4:] + 0.5
         elif FLAGS.dsize_only:
-            l = dat['latents_values']
+            l = dat["latents_values"]
             # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
-            mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
-                                                   31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
-            self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
-            self.label = np.tile(dat['latents_values'][mask], (10000, 1))
-            self.label = (self.label[:, 2:3])
+            mask = (
+                (l[:, 3] == 30 * np.pi / 39)
+                & (l[:, 4] == 16 / 31)
+                & (l[:, 5] == 16 / 31)
+                & (l[:, 1] == 1)
+            )
+            self.data = np.tile(dat["imgs"][mask], (10000, 1, 1))
+            self.label = np.tile(dat["latents_values"][mask], (10000, 1))
+            self.label = self.label[:, 2:3]
         elif FLAGS.drot_only:
-            l = dat['latents_values']
-            mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
-                                       31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
-            self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
-            self.label = np.tile(dat['latents_values'][mask], (100, 1))
-            self.label = (self.label[:, 3:4])
+            l = dat["latents_values"]
+            mask = (
+                (l[:, 2] == 0.5)
+                & (l[:, 4] == 16 / 31)
+                & (l[:, 5] == 16 / 31)
+                & (l[:, 1] == 1)
+            )
+            self.data = np.tile(dat["imgs"][mask], (100, 1, 1))
+            self.label = np.tile(dat["latents_values"][mask], (100, 1))
+            self.label = self.label[:, 3:4]
             self.label = np.concatenate(
-                [np.cos(self.label), np.sin(self.label)], axis=1)
+                [np.cos(self.label), np.sin(self.label)], axis=1
+            )
         elif FLAGS.dsprites_restrict:
-            l = dat['latents_values']
+            l = dat["latents_values"]
             mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
 
-            self.data = dat['imgs'][mask]
-            self.label = dat['latents_values'][mask]
+            self.data = dat["imgs"][mask]
+            self.label = dat["latents_values"][mask]
         else:
-            self.data = dat['imgs']
-            self.label = dat['latents_values']
+            self.data = dat["imgs"]
+            self.label = dat["latents_values"]
 
             if cond_size:
                 self.label = self.label[:, 2:3]
@@ -439,7 +461,8 @@ class DSprites(Dataset):
             elif cond_rot:
                 self.label = self.label[:, 3:4]
                 self.label = np.concatenate(
-                    [np.cos(self.label), np.sin(self.label)], axis=1)
+                    [np.cos(self.label), np.sin(self.label)], axis=1
+                )
             else:
                 self.label = self.label[:, 1:2]
@@ -452,20 +475,20 @@ class DSprites(Dataset):
         im = self.data[index]
         image_size = 64
 
-        if not (
-                FLAGS.dpos_only or FLAGS.dsize_only) and (
-                not FLAGS.cond_size) and (
-                not FLAGS.cond_pos) and (
-                not FLAGS.cond_rot) and (
-                not FLAGS.drot_only):
-            label = self.identity[self.label[index].astype(
-                np.int32) - 1].squeeze()
+        if (
+            not (FLAGS.dpos_only or FLAGS.dsize_only)
+            and (not FLAGS.cond_size)
+            and (not FLAGS.cond_pos)
+            and (not FLAGS.cond_rot)
+            and (not FLAGS.drot_only)
+        ):
+            label = self.identity[self.label[index].astype(np.int32) - 1].squeeze()
         else:
             label = self.label[index]
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
-        elif FLAGS.datasource == 'random':
+        elif FLAGS.datasource == "random":
             im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
 
         return im_corrupt, im, label
@@ -478,25 +501,20 @@ class Imagenet(Dataset):
             for i in range(1, 11):
                 f = pickle.load(
                     open(
-                        osp.join(
-                            FLAGS.imagenet_path,
-                            'train_data_batch_{}'.format(i)),
-                        'rb'))
+                        osp.join(FLAGS.imagenet_path, "train_data_batch_{}".format(i)),
+                        "rb",
+                    )
+                )
                 if i == 1:
-                    labels = f['labels']
-                    data = f['data']
+                    labels = f["labels"]
+                    data = f["data"]
                 else:
-                    labels.extend(f['labels'])
-                    data = np.vstack((data, f['data']))
+                    labels.extend(f["labels"])
+                    data = np.vstack((data, f["data"]))
         else:
-            f = pickle.load(
-                open(
-                    osp.join(
-                        FLAGS.imagenet_path,
-                        'val_data'),
-                    'rb'))
-            labels = f['labels']
-            data = f['data']
+            f = pickle.load(open(osp.join(FLAGS.imagenet_path, "val_data"), "rb"))
+            labels = f["labels"]
+            data = f["data"]
 
         self.labels = labels
         self.data = data
@@ -520,11 +538,10 @@ class Imagenet(Dataset):
         im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
 
         np.random.seed((index + int(time.time() * 1e7)) % 2**32)
 
-        if FLAGS.datasource == 'default':
+        if FLAGS.datasource == "default":
             im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
-        elif FLAGS.datasource == 'random':
-            im_corrupt = np.random.uniform(
-                0.0, 1.0, (image_size, image_size, 3))
+        elif FLAGS.datasource == "random":
+            im_corrupt = np.random.uniform(0.0, 1.0, (image_size, image_size, 3))
 
         return im_corrupt, im, label
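
For orientation, a minimal sketch of how one of the loaders touched by this commit is typically consumed; the module name `data` is an assumption (the changed file's name is not shown above), and the tf flags read inside __getitem__ (e.g. FLAGS.datasource) must be defined and parsed by the surrounding training script before iteration:

    import torch
    from torch.utils.data import DataLoader

    from data import Cifar10  # assumed module name for the file in this diff

    # Each __getitem__ returns numpy arrays; the default collate_fn
    # converts them to torch tensors batch by batch.
    dataset = Cifar10(train=True, augment=False, noise=True, rescale=1.0)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

    for im_corrupt, im, label in loader:
        # im_corrupt: noisy/random initialization, im: clean image, label: one-hot
        print(im_corrupt.shape, im.shape, label.shape)
        break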