mirror of https://github.com/autistic-symposium/ml-ai-agents-py.git
synced 2025-08-17 10:40:13 -04:00
chores: refactor for the new ai research, add linter, gh action, etc (#27)
parent fb4ab80dc3
commit d5467e559f
40 changed files with 5177 additions and 2476 deletions
EBMs/fid.py (214 changed lines)
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-''' Calculates the Frechet Inception Distance (FID) to evalulate GANs.
+"""Calculates the Frechet Inception Distance (FID) to evaluate GANs.
 
 The FID metric calculates the distance between two distributions of images.
 Typically, we have summary statistics (mean & covariance matrix) of one
@@ -14,28 +14,33 @@ the pool_3 layer of the inception net for generated samples and real world
 samples respectivly.
 
 See --help to see further details.
-'''
+"""
 
 from __future__ import absolute_import, division, print_function
-import numpy as np
+
 import os
 import gzip, pickle
-import tensorflow as tf
-from scipy.misc import imread
-from scipy import linalg
 import pathlib
-import urllib
 import tarfile
 import urllib
 import warnings
 
-MODEL_DIR = '/tmp/imagenet'
-DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+import numpy as np
+import tensorflow as tf
+from scipy import linalg
+from scipy.misc import imread
+
+MODEL_DIR = "/tmp/imagenet"
+DATA_URL = (
+    "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
+)
 pool3 = None
 
+
 class InvalidFIDException(Exception):
     pass
 
-#-------------------------------------------------------------------------------
+
+# -------------------------------------------------------------------------------
 def get_fid_score(images, images_gt):
     images = np.stack(images, 0)
     images_gt = np.stack(images_gt, 0)
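Note on the import block above: from scipy.misc import imread survives the refactor, but scipy.misc.imread was deprecated in SciPy 1.0 and removed in SciPy 1.2, so this module only runs against an older SciPy (with Pillow installed). A minimal compatibility sketch, assuming the imageio package is acceptable as a fallback:

    try:
        from scipy.misc import imread  # works only on SciPy < 1.2
    except ImportError:
        # assumption: imageio is installed; its imread also returns a numpy array
        from imageio import imread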
@@ -52,34 +57,38 @@ def get_fid_score(images, images_gt):
 def create_inception_graph(pth):
     """Creates a graph from saved GraphDef file."""
     # Creates graph from saved graph_def.pb.
-    with tf.gfile.FastGFile( pth, 'rb') as f:
+    with tf.gfile.FastGFile(pth, "rb") as f:
         graph_def = tf.GraphDef()
-        graph_def.ParseFromString( f.read())
-        _ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
-#-------------------------------------------------------------------------------
+        graph_def.ParseFromString(f.read())
+        _ = tf.import_graph_def(graph_def, name="FID_Inception_Net")
+
+
+# -------------------------------------------------------------------------------
 
 
 # code for handling inception net derived from
 # https://github.com/openai/improved-gan/blob/master/inception_score/model.py
 def _get_inception_layer(sess):
-    """Prepares inception net for batched usage and returns pool_3 layer. """
-    layername = 'FID_Inception_Net/pool_3:0'
+    """Prepares inception net for batched usage and returns pool_3 layer."""
+    layername = "FID_Inception_Net/pool_3:0"
     pool3 = sess.graph.get_tensor_by_name(layername)
     ops = pool3.graph.get_operations()
     for op_idx, op in enumerate(ops):
         for o in op.outputs:
             shape = o.get_shape()
             if shape._dims != []:
-              shape = [s.value for s in shape]
-              new_shape = []
-              for j, s in enumerate(shape):
-                if s == 1 and j == 0:
-                  new_shape.append(None)
-                else:
-                  new_shape.append(s)
-              o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
+                shape = [s.value for s in shape]
+                new_shape = []
+                for j, s in enumerate(shape):
+                    if s == 1 and j == 0:
+                        new_shape.append(None)
+                    else:
+                        new_shape.append(s)
+                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
     return pool3
-#-------------------------------------------------------------------------------
 
 
+# -------------------------------------------------------------------------------
+
+
 def get_activations(images, sess, batch_size=50, verbose=False):
@@ -100,23 +109,27 @@ def get_activations(images, sess, batch_size=50, verbose=False):
     # inception_layer = _get_inception_layer(sess)
     d0 = images.shape[0]
     if batch_size > d0:
-        print("warning: batch size is bigger than the data size. setting batch size to data size")
+        print(
+            "warning: batch size is bigger than the data size. setting batch size to data size"
+        )
         batch_size = d0
-    n_batches = d0//batch_size
-    n_used_imgs = n_batches*batch_size
-    pred_arr = np.empty((n_used_imgs,2048))
+    n_batches = d0 // batch_size
+    n_used_imgs = n_batches * batch_size
+    pred_arr = np.empty((n_used_imgs, 2048))
     for i in range(n_batches):
         if verbose:
-            print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
-        start = i*batch_size
+            print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
+        start = i * batch_size
         end = start + batch_size
         batch = images[start:end]
-        pred = sess.run(pool3, {'ExpandDims:0': batch})
-        pred_arr[start:end] = pred.reshape(batch_size,-1)
+        pred = sess.run(pool3, {"ExpandDims:0": batch})
+        pred_arr[start:end] = pred.reshape(batch_size, -1)
     if verbose:
         print(" done")
     return pred_arr
-#-------------------------------------------------------------------------------
 
 
+# -------------------------------------------------------------------------------
+
+
 def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
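Note on the batching in get_activations: n_batches = d0 // batch_size floor-divides, so any images past the last full batch never reach pred_arr. A quick illustration of the arithmetic with hypothetical sizes:

    d0, batch_size = 120, 50              # 120 images, batches of 50
    n_batches = d0 // batch_size          # 2 full batches
    n_used_imgs = n_batches * batch_size  # 100 images are propagated
    print(d0 - n_used_imgs)               # 20 trailing images are silently dropped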
@@ -147,15 +160,22 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
     sigma1 = np.atleast_2d(sigma1)
     sigma2 = np.atleast_2d(sigma2)
 
-    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
-    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
+    assert (
+        mu1.shape == mu2.shape
+    ), "Training and test mean vectors have different lengths"
+    assert (
+        sigma1.shape == sigma2.shape
+    ), "Training and test covariances have different dimensions"
 
     diff = mu1 - mu2
 
     # product might be almost singular
     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
     if not np.isfinite(covmean).all():
-        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
+        msg = (
+            "fid calculation produces singular product; adding %s to diagonal of cov estimates"
+            % eps
+        )
         warnings.warn(msg)
         offset = np.eye(sigma1.shape[0]) * eps
         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@@ -170,7 +190,9 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
     tr_covmean = np.trace(covmean)
 
     return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
-#-------------------------------------------------------------------------------
 
 
+# -------------------------------------------------------------------------------
+
+
 def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
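For reference, the return statement above is the closed-form Fréchet distance between the two Gaussians fitted to the pool_3 activations (Heusel et al., 2017):

    d^2\big((\mu_1, \Sigma_1), (\mu_2, \Sigma_2)\big)
        = \lVert \mu_1 - \mu_2 \rVert_2^2
        + \operatorname{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\big)

which maps term by term onto diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean.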
@@ -193,47 +215,52 @@ def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
     mu = np.mean(act, axis=0)
     sigma = np.cov(act, rowvar=False)
     return mu, sigma
-#-------------------------------------------------------------------------------
 
 
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
+
+
+# -------------------------------------------------------------------------------
 # The following functions aren't needed for calculating the FID
 # they're just here to make this module work as a stand-alone script
 # for calculating FID scores
-#-------------------------------------------------------------------------------
+# -------------------------------------------------------------------------------
 def check_or_download_inception(inception_path):
-    ''' Checks if the path to the inception file is valid, or downloads
-        the file if it is not present. '''
-    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+    """Checks if the path to the inception file is valid, or downloads
+    the file if it is not present."""
+    INCEPTION_URL = (
+        "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
+    )
     if inception_path is None:
-        inception_path = '/tmp'
+        inception_path = "/tmp"
     inception_path = pathlib.Path(inception_path)
-    model_file = inception_path / 'classify_image_graph_def.pb'
+    model_file = inception_path / "classify_image_graph_def.pb"
     if not model_file.exists():
         print("Downloading Inception model")
-        from urllib import request
         import tarfile
+        from urllib import request
+
         fn, _ = request.urlretrieve(INCEPTION_URL)
-        with tarfile.open(fn, mode='r') as f:
-            f.extract('classify_image_graph_def.pb', str(model_file.parent))
+        with tarfile.open(fn, mode="r") as f:
+            f.extract("classify_image_graph_def.pb", str(model_file.parent))
     return str(model_file)
 
 
 def _handle_path(path, sess):
-    if path.endswith('.npz'):
+    if path.endswith(".npz"):
         f = np.load(path)
-        m, s = f['mu'][:], f['sigma'][:]
+        m, s = f["mu"][:], f["sigma"][:]
         f.close()
     else:
         path = pathlib.Path(path)
-        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
+        files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
         x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
         m, s = calculate_activation_statistics(x, sess)
     return m, s
 
 
 def calculate_fid_given_paths(paths, inception_path):
-    ''' Calculates the FID of two paths. '''
+    """Calculates the FID of two paths."""
     inception_path = check_or_download_inception(inception_path)
 
     for p in paths:
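A usage sketch for the stand-alone path (folder names are hypothetical; per its docstring, calculate_fid_given_paths is assumed to return the scalar FID):

    # Each folder holds .jpg/.png images (see _handle_path above); passing
    # None as inception_path downloads the model to /tmp on first use.
    fid_value = calculate_fid_given_paths(["real_images/", "generated_images/"], None)
    print("FID:", fid_value)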
@@ -250,43 +277,48 @@ def calculate_fid_given_paths(paths, inception_path):
 
 
 def _init_inception():
-  global pool3
-  if not os.path.exists(MODEL_DIR):
-    os.makedirs(MODEL_DIR)
-  filename = DATA_URL.split('/')[-1]
-  filepath = os.path.join(MODEL_DIR, filename)
-  if not os.path.exists(filepath):
-    def _progress(count, block_size, total_size):
-      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
-          filename, float(count * block_size) / float(total_size) * 100.0))
-      sys.stdout.flush()
-    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
-    print()
-    statinfo = os.stat(filepath)
-    print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
-  tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
-  with tf.gfile.FastGFile(os.path.join(
-      MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
-    graph_def = tf.GraphDef()
-    graph_def.ParseFromString(f.read())
-    _ = tf.import_graph_def(graph_def, name='')
-  # Works with an arbitrary minibatch size.
-  with tf.Session() as sess:
-    pool3 = sess.graph.get_tensor_by_name('pool_3:0')
-    ops = pool3.graph.get_operations()
-    for op_idx, op in enumerate(ops):
-      for o in op.outputs:
-        shape = o.get_shape()
-        if shape._dims != []:
-          shape = [s.value for s in shape]
-          new_shape = []
-          for j, s in enumerate(shape):
-            if s == 1 and j == 0:
-              new_shape.append(None)
-            else:
-              new_shape.append(s)
-          o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
+    global pool3
+    if not os.path.exists(MODEL_DIR):
+        os.makedirs(MODEL_DIR)
+    filename = DATA_URL.split("/")[-1]
+    filepath = os.path.join(MODEL_DIR, filename)
+    if not os.path.exists(filepath):
+
+        def _progress(count, block_size, total_size):
+            sys.stdout.write(
+                "\r>> Downloading %s %.1f%%"
+                % (filename, float(count * block_size) / float(total_size) * 100.0)
+            )
+            sys.stdout.flush()
+
+        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
+        print()
+        statinfo = os.stat(filepath)
+        print("Successfully downloaded", filename, statinfo.st_size, "bytes.")
+    tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
+    with tf.gfile.FastGFile(
+        os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
+    ) as f:
+        graph_def = tf.GraphDef()
+        graph_def.ParseFromString(f.read())
+        _ = tf.import_graph_def(graph_def, name="")
+    # Works with an arbitrary minibatch size.
+    with tf.Session() as sess:
+        pool3 = sess.graph.get_tensor_by_name("pool_3:0")
+        ops = pool3.graph.get_operations()
+        for op_idx, op in enumerate(ops):
+            for o in op.outputs:
+                shape = o.get_shape()
+                if shape._dims != []:
+                    shape = [s.value for s in shape]
+                    new_shape = []
+                    for j, s in enumerate(shape):
+                        if s == 1 and j == 0:
+                            new_shape.append(None)
+                        else:
+                            new_shape.append(s)
+                    o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
 
 
 if pool3 is None:
-  _init_inception()
+    _init_inception()
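Note the module-level tail: importing this file calls _init_inception() whenever pool3 is unset, so the download of inception-2015-12-05.tgz into /tmp/imagenet and the batch-dimension patching happen as an import side effect. A sketch of what a caller sees (hypothetical script; it assumes get_fid_score returns the scalar FID, which the excerpt above does not confirm):

    import numpy as np

    import fid  # import itself may download and unpack the Inception graph

    # two lists of HxWx3 images in [0, 255]; get_fid_score stacks each list
    # with np.stack and compares their pool_3 statistics
    images_gen = [np.random.uniform(0, 255, (32, 32, 3)) for _ in range(64)]
    images_real = [np.random.uniform(0, 255, (32, 32, 3)) for _ in range(64)]
    print("FID:", fid.get_fid_score(images_gen, images_real))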