chores: refactor for the new ai research, add linter, gh action, etc (#27)

This commit is contained in:
Marina von Steinkirch, PhD 2025-08-13 21:49:46 +08:00 committed by von-steinkirch
parent fb4ab80dc3
commit d5467e559f
40 changed files with 5177 additions and 2476 deletions

View file

@ -1,5 +1,5 @@
#!/usr/bin/env python3
''' Calculates the Frechet Inception Distance (FID) to evalulate GANs.
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs.
The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
@ -14,28 +14,33 @@ the pool_3 layer of the inception net for generated samples and real world
samples respectivly.
See --help to see further details.
'''
"""
from __future__ import absolute_import, division, print_function
import numpy as np
import os
import gzip, pickle
import tensorflow as tf
from scipy.misc import imread
from scipy import linalg
import pathlib
import urllib
import tarfile
import urllib
import warnings
MODEL_DIR = '/tmp/imagenet'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
import numpy as np
import tensorflow as tf
from scipy import linalg
from scipy.misc import imread
MODEL_DIR = "/tmp/imagenet"
DATA_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
pool3 = None
class InvalidFIDException(Exception):
pass
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def get_fid_score(images, images_gt):
images = np.stack(images, 0)
images_gt = np.stack(images_gt, 0)
@ -52,34 +57,38 @@ def get_fid_score(images, images_gt):
def create_inception_graph(pth):
"""Creates a graph from saved GraphDef file."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile( pth, 'rb') as f:
with tf.gfile.FastGFile(pth, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString( f.read())
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
#-------------------------------------------------------------------------------
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name="FID_Inception_Net")
# -------------------------------------------------------------------------------
# code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
"""Prepares inception net for batched usage and returns pool_3 layer. """
layername = 'FID_Inception_Net/pool_3:0'
"""Prepares inception net for batched usage and returns pool_3 layer."""
layername = "FID_Inception_Net/pool_3:0"
pool3 = sess.graph.get_tensor_by_name(layername)
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
return pool3
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=50, verbose=False):
@ -100,23 +109,27 @@ def get_activations(images, sess, batch_size=50, verbose=False):
# inception_layer = _get_inception_layer(sess)
d0 = images.shape[0]
if batch_size > d0:
print("warning: batch size is bigger than the data size. setting batch size to data size")
print(
"warning: batch size is bigger than the data size. setting batch size to data size"
)
batch_size = d0
n_batches = d0//batch_size
n_used_imgs = n_batches*batch_size
pred_arr = np.empty((n_used_imgs,2048))
n_batches = d0 // batch_size
n_used_imgs = n_batches * batch_size
pred_arr = np.empty((n_used_imgs, 2048))
for i in range(n_batches):
if verbose:
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
start = i*batch_size
print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
start = i * batch_size
end = start + batch_size
batch = images[start:end]
pred = sess.run(pool3, {'ExpandDims:0': batch})
pred_arr[start:end] = pred.reshape(batch_size,-1)
pred = sess.run(pool3, {"ExpandDims:0": batch})
pred_arr[start:end] = pred.reshape(batch_size, -1)
if verbose:
print(" done")
return pred_arr
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
@ -147,15 +160,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
assert (
mu1.shape == mu2.shape
), "Training and test mean vectors have different lengths"
assert (
sigma1.shape == sigma2.shape
), "Training and test covariances have different dimensions"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
msg = (
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
% eps
)
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@ -170,7 +190,9 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
@ -193,47 +215,52 @@ def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
# The following functions aren't needed for calculating the FID
# they're just here to make this module work as a stand-alone script
# for calculating FID scores
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------
def check_or_download_inception(inception_path):
''' Checks if the path to the inception file is valid, or downloads
the file if it is not present. '''
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
"""Checks if the path to the inception file is valid, or downloads
the file if it is not present."""
INCEPTION_URL = (
"http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
)
if inception_path is None:
inception_path = '/tmp'
inception_path = "/tmp"
inception_path = pathlib.Path(inception_path)
model_file = inception_path / 'classify_image_graph_def.pb'
model_file = inception_path / "classify_image_graph_def.pb"
if not model_file.exists():
print("Downloading Inception model")
from urllib import request
import tarfile
from urllib import request
fn, _ = request.urlretrieve(INCEPTION_URL)
with tarfile.open(fn, mode='r') as f:
f.extract('classify_image_graph_def.pb', str(model_file.parent))
with tarfile.open(fn, mode="r") as f:
f.extract("classify_image_graph_def.pb", str(model_file.parent))
return str(model_file)
def _handle_path(path, sess):
if path.endswith('.npz'):
if path.endswith(".npz"):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
m, s = f["mu"][:], f["sigma"][:]
f.close()
else:
path = pathlib.Path(path)
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
m, s = calculate_activation_statistics(x, sess)
return m, s
def calculate_fid_given_paths(paths, inception_path):
''' Calculates the FID of two paths. '''
"""Calculates the FID of two paths."""
inception_path = check_or_download_inception(inception_path)
for p in paths:
@ -250,43 +277,48 @@ def calculate_fid_given_paths(paths, inception_path):
def _init_inception():
global pool3
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
with tf.gfile.FastGFile(os.path.join(
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
# Works with an arbitrary minibatch size.
with tf.Session() as sess:
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
global pool3
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
filename = DATA_URL.split("/")[-1]
filepath = os.path.join(MODEL_DIR, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write(
"\r>> Downloading %s %.1f%%"
% (filename, float(count * block_size) / float(total_size) * 100.0)
)
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
with tf.gfile.FastGFile(
os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
) as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name="")
# Works with an arbitrary minibatch size.
with tf.Session() as sess:
pool3 = sess.graph.get_tensor_by_name("pool_3:0")
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
if shape._dims != []:
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
if pool3 is None:
_init_inception()
_init_inception()