mirror of
https://github.com/autistic-symposium/ml-ai-agents-py.git
synced 2025-04-25 10:09:09 -04:00
293 lines
11 KiB
Python
293 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
''' Calculates the Frechet Inception Distance (FID) to evalulate GANs.
|
|
|
|
The FID metric calculates the distance between two distributions of images.
|
|
Typically, we have summary statistics (mean & covariance matrix) of one
|
|
of these distributions, while the 2nd distribution is given by a GAN.
|
|
|
|
When run as a stand-alone program, it compares the distribution of
|
|
images that are stored as PNG/JPEG at a specified location with a
|
|
distribution given by summary statistics (in pickle format).
|
|
|
|
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
|
the pool_3 layer of the inception net for generated samples and real world
|
|
samples respectivly.
|
|
|
|
See --help to see further details.
|
|
'''
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
import numpy as np
|
|
import os
|
|
import gzip, pickle
|
|
import tensorflow as tf
|
|
from scipy.misc import imread
|
|
from scipy import linalg
|
|
import pathlib
|
|
import urllib
|
|
import tarfile
|
|
import warnings
|
|
|
|
MODEL_DIR = '/tmp/imagenet'
|
|
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
|
pool3 = None
|
|
|
|
class InvalidFIDException(Exception):
|
|
pass
|
|
|
|
#-------------------------------------------------------------------------------
|
|
def get_fid_score(images, images_gt):
|
|
images = np.stack(images, 0)
|
|
images_gt = np.stack(images_gt, 0)
|
|
|
|
with tf.Session() as sess:
|
|
m1, s1 = calculate_activation_statistics(images, sess)
|
|
m2, s2 = calculate_activation_statistics(images_gt, sess)
|
|
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
|
|
|
print("Obtained fid value of {}".format(fid_value))
|
|
return fid_value
|
|
|
|
|
|
def create_inception_graph(pth):
|
|
"""Creates a graph from saved GraphDef file."""
|
|
# Creates graph from saved graph_def.pb.
|
|
with tf.gfile.FastGFile( pth, 'rb') as f:
|
|
graph_def = tf.GraphDef()
|
|
graph_def.ParseFromString( f.read())
|
|
_ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
# code for handling inception net derived from
|
|
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
|
|
def _get_inception_layer(sess):
|
|
"""Prepares inception net for batched usage and returns pool_3 layer. """
|
|
layername = 'FID_Inception_Net/pool_3:0'
|
|
pool3 = sess.graph.get_tensor_by_name(layername)
|
|
ops = pool3.graph.get_operations()
|
|
for op_idx, op in enumerate(ops):
|
|
for o in op.outputs:
|
|
shape = o.get_shape()
|
|
if shape._dims != []:
|
|
shape = [s.value for s in shape]
|
|
new_shape = []
|
|
for j, s in enumerate(shape):
|
|
if s == 1 and j == 0:
|
|
new_shape.append(None)
|
|
else:
|
|
new_shape.append(s)
|
|
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
|
return pool3
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
def get_activations(images, sess, batch_size=50, verbose=False):
|
|
"""Calculates the activations of the pool_3 layer for all images.
|
|
|
|
Params:
|
|
-- images : Numpy array of dimension (n_images, hi, wi, 3). The values
|
|
must lie between 0 and 256.
|
|
-- sess : current session
|
|
-- batch_size : the images numpy array is split into batches with batch size
|
|
batch_size. A reasonable batch size depends on the disposable hardware.
|
|
-- verbose : If set to True and parameter out_step is given, the number of calculated
|
|
batches is reported.
|
|
Returns:
|
|
-- A numpy array of dimension (num images, 2048) that contains the
|
|
activations of the given tensor when feeding inception with the query tensor.
|
|
"""
|
|
# inception_layer = _get_inception_layer(sess)
|
|
d0 = images.shape[0]
|
|
if batch_size > d0:
|
|
print("warning: batch size is bigger than the data size. setting batch size to data size")
|
|
batch_size = d0
|
|
n_batches = d0//batch_size
|
|
n_used_imgs = n_batches*batch_size
|
|
pred_arr = np.empty((n_used_imgs,2048))
|
|
for i in range(n_batches):
|
|
if verbose:
|
|
print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
|
|
start = i*batch_size
|
|
end = start + batch_size
|
|
batch = images[start:end]
|
|
pred = sess.run(pool3, {'ExpandDims:0': batch})
|
|
pred_arr[start:end] = pred.reshape(batch_size,-1)
|
|
if verbose:
|
|
print(" done")
|
|
return pred_arr
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
|
"""Numpy implementation of the Frechet Distance.
|
|
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
|
and X_2 ~ N(mu_2, C_2) is
|
|
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
|
Stable version by Dougal J. Sutherland.
|
|
|
|
Params:
|
|
-- mu1 : Numpy array containing the activations of the pool_3 layer of the
|
|
inception net ( like returned by the function 'get_predictions')
|
|
for generated samples.
|
|
-- mu2 : The sample mean over activations of the pool_3 layer, precalcualted
|
|
on an representive data set.
|
|
-- sigma1: The covariance matrix over activations of the pool_3 layer for
|
|
generated samples.
|
|
-- sigma2: The covariance matrix over activations of the pool_3 layer,
|
|
precalcualted on an representive data set.
|
|
|
|
Returns:
|
|
-- : The Frechet Distance.
|
|
"""
|
|
|
|
mu1 = np.atleast_1d(mu1)
|
|
mu2 = np.atleast_1d(mu2)
|
|
|
|
sigma1 = np.atleast_2d(sigma1)
|
|
sigma2 = np.atleast_2d(sigma2)
|
|
|
|
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
|
|
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
|
|
|
|
diff = mu1 - mu2
|
|
|
|
# product might be almost singular
|
|
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
|
if not np.isfinite(covmean).all():
|
|
msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
|
|
warnings.warn(msg)
|
|
offset = np.eye(sigma1.shape[0]) * eps
|
|
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
|
|
|
# numerical error might give slight imaginary component
|
|
if np.iscomplexobj(covmean):
|
|
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
|
m = np.max(np.abs(covmean.imag))
|
|
raise ValueError("Imaginary component {}".format(m))
|
|
covmean = covmean.real
|
|
|
|
tr_covmean = np.trace(covmean)
|
|
|
|
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
|
|
"""Calculation of the statistics used by the FID.
|
|
Params:
|
|
-- images : Numpy array of dimension (n_images, hi, wi, 3). The values
|
|
must lie between 0 and 255.
|
|
-- sess : current session
|
|
-- batch_size : the images numpy array is split into batches with batch size
|
|
batch_size. A reasonable batch size depends on the available hardware.
|
|
-- verbose : If set to True and parameter out_step is given, the number of calculated
|
|
batches is reported.
|
|
Returns:
|
|
-- mu : The mean over samples of the activations of the pool_3 layer of
|
|
the incption model.
|
|
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
|
the incption model.
|
|
"""
|
|
act = get_activations(images, sess, batch_size, verbose)
|
|
mu = np.mean(act, axis=0)
|
|
sigma = np.cov(act, rowvar=False)
|
|
return mu, sigma
|
|
#-------------------------------------------------------------------------------
|
|
|
|
|
|
#-------------------------------------------------------------------------------
|
|
# The following functions aren't needed for calculating the FID
|
|
# they're just here to make this module work as a stand-alone script
|
|
# for calculating FID scores
|
|
#-------------------------------------------------------------------------------
|
|
def check_or_download_inception(inception_path):
|
|
''' Checks if the path to the inception file is valid, or downloads
|
|
the file if it is not present. '''
|
|
INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
|
if inception_path is None:
|
|
inception_path = '/tmp'
|
|
inception_path = pathlib.Path(inception_path)
|
|
model_file = inception_path / 'classify_image_graph_def.pb'
|
|
if not model_file.exists():
|
|
print("Downloading Inception model")
|
|
from urllib import request
|
|
import tarfile
|
|
fn, _ = request.urlretrieve(INCEPTION_URL)
|
|
with tarfile.open(fn, mode='r') as f:
|
|
f.extract('classify_image_graph_def.pb', str(model_file.parent))
|
|
return str(model_file)
|
|
|
|
|
|
def _handle_path(path, sess):
|
|
if path.endswith('.npz'):
|
|
f = np.load(path)
|
|
m, s = f['mu'][:], f['sigma'][:]
|
|
f.close()
|
|
else:
|
|
path = pathlib.Path(path)
|
|
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
|
|
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
|
|
m, s = calculate_activation_statistics(x, sess)
|
|
return m, s
|
|
|
|
|
|
def calculate_fid_given_paths(paths, inception_path):
|
|
''' Calculates the FID of two paths. '''
|
|
inception_path = check_or_download_inception(inception_path)
|
|
|
|
for p in paths:
|
|
if not os.path.exists(p):
|
|
raise RuntimeError("Invalid path: %s" % p)
|
|
|
|
create_inception_graph(str(inception_path))
|
|
with tf.Session() as sess:
|
|
sess.run(tf.global_variables_initializer())
|
|
m1, s1 = _handle_path(paths[0], sess)
|
|
m2, s2 = _handle_path(paths[1], sess)
|
|
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
|
return fid_value
|
|
|
|
|
|
def _init_inception():
|
|
global pool3
|
|
if not os.path.exists(MODEL_DIR):
|
|
os.makedirs(MODEL_DIR)
|
|
filename = DATA_URL.split('/')[-1]
|
|
filepath = os.path.join(MODEL_DIR, filename)
|
|
if not os.path.exists(filepath):
|
|
def _progress(count, block_size, total_size):
|
|
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
|
filename, float(count * block_size) / float(total_size) * 100.0))
|
|
sys.stdout.flush()
|
|
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
|
print()
|
|
statinfo = os.stat(filepath)
|
|
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
|
|
tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
|
|
with tf.gfile.FastGFile(os.path.join(
|
|
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
|
|
graph_def = tf.GraphDef()
|
|
graph_def.ParseFromString(f.read())
|
|
_ = tf.import_graph_def(graph_def, name='')
|
|
# Works with an arbitrary minibatch size.
|
|
with tf.Session() as sess:
|
|
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
|
|
ops = pool3.graph.get_operations()
|
|
for op_idx, op in enumerate(ops):
|
|
for o in op.outputs:
|
|
shape = o.get_shape()
|
|
if shape._dims != []:
|
|
shape = [s.value for s in shape]
|
|
new_shape = []
|
|
for j, s in enumerate(shape):
|
|
if s == 1 and j == 0:
|
|
new_shape.append(None)
|
|
else:
|
|
new_shape.append(s)
|
|
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
|
|
|
|
|
|
if pool3 is None:
|
|
_init_inception()
|