""" Utility functions. """
import numpy as np
import os
import random
import tensorflow as tf
import warnings
from tensorflow.contrib.layers.python import layers as tf_layers
from tensorflow.python.platform import flags
from tensorflow.contrib.framework import sort
FLAGS = flags.FLAGS
flags.DEFINE_integer('spec_iter', 1, 'Number of iterations to normalize spectrum of matrix')
flags.DEFINE_float('spec_norm_val', 1.0, 'Desired norm of matrices')
flags.DEFINE_bool('downsample', False, 'Whether to do average pool downsampling')
flags.DEFINE_bool('spec_eval', False, 'Set to true to prevent spectral updates')
def get_median(v):
    v = tf.reshape(v, [-1])
    m = tf.shape(v)[0] // 2
    # top_k returns a (values, indices) pair; the m-th largest value is the
    # upper median of the flattened tensor.
    return tf.nn.top_k(v, m).values[m - 1]
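# Illustrative sketch (not part of the original module): with m = n // 2, the
# m-th largest entry returned above is the upper median of the flattened
# tensor. The constant below is an arbitrary example value.
def _example_get_median():
    with tf.Graph().as_default(), tf.Session() as sess:
        v = tf.constant([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])
        return sess.run(get_median(v))  # 4.0, the 3rd largest of 6 values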
def set_seed(seed):
    # torch is imported lazily so the rest of the module does not require it.
    import torch
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    tf.set_random_seed(seed)
def swish(inp):
return inp * tf.nn.sigmoid(inp)
class ReplayBuffer(object):
def __init__(self, size):
"""Create Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
"""
self._storage = []
self._maxsize = size
self._next_idx = 0
def __len__(self):
return len(self._storage)
def add(self, ims):
batch_size = ims.shape[0]
if self._next_idx >= len(self._storage):
self._storage.extend(list(ims))
else:
if batch_size + self._next_idx < self._maxsize:
self._storage[self._next_idx:self._next_idx +
batch_size] = list(ims)
else:
split_idx = self._maxsize - self._next_idx
self._storage[self._next_idx:] = list(ims)[:split_idx]
self._storage[:batch_size - split_idx] = list(ims)[split_idx:]
self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
def _encode_sample(self, idxes):
ims = []
for i in idxes:
ims.append(self._storage[i])
return np.array(ims)
def sample(self, batch_size):
"""Sample a batch of experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
"""
idxes = [random.randint(0, len(self._storage) - 1)
for _ in range(batch_size)]
return self._encode_sample(idxes)
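# Illustrative usage sketch (shapes are arbitrary assumptions): the buffer
# stores image batches and returns uniformly sampled batches of images.
def _example_replay_buffer():
    buf = ReplayBuffer(size=1000)
    ims = np.random.uniform(size=(16, 32, 32, 3)).astype(np.float32)
    buf.add(ims)
    return buf.sample(8)  # array of shape (8, 32, 32, 3)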
def get_weight(
name,
shape,
gain=np.sqrt(2),
use_wscale=False,
fan_in=None,
spec_norm=False,
zero=False,
fc=False):
if fan_in is None:
fan_in = np.prod(shape[:-1])
std = gain / np.sqrt(fan_in) # He init
if use_wscale:
wscale = tf.constant(np.float32(std), name=name + 'wscale')
var = tf.get_variable(
name + 'weight',
shape=shape,
initializer=tf.initializers.random_normal()) * wscale
elif spec_norm:
if zero:
var = tf.get_variable(
shape=shape,
name=name + 'weight',
initializer=tf.initializers.random_normal(
stddev=1e-10))
var = spectral_normed_weight(var, name, lower_bound=True, fc=fc)
else:
var = tf.get_variable(
name + 'weight',
shape=shape,
initializer=tf.initializers.random_normal())
var = spectral_normed_weight(var, name, fc=fc)
else:
if zero:
var = tf.get_variable(
name + 'weight',
shape=shape,
                initializer=tf.initializers.zeros())
else:
var = tf.get_variable(
name + 'weight',
shape=shape,
initializer=tf.contrib.layers.xavier_initializer(
dtype=tf.float32))
return var
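# Illustrative sketch (assumed scope and shape): a 3x3 convolution kernel
# created with spectral normalization; the normalization strength follows the
# spec_iter / spec_norm_val flags defined at the top of this module.
def _example_get_weight():
    with tf.variable_scope('example_conv'):
        return get_weight('c', [3, 3, 64, 128], spec_norm=True)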
def pixel_norm(x, epsilon=1e-8):
with tf.variable_scope('PixelNorm'):
return x * tf.rsqrt(tf.reduce_mean(tf.square(x),
axis=[1, 2], keepdims=True) + epsilon)
# helper
def get_images(paths, labels, nb_samples=None, shuffle=True):
    """Pair each label with image paths drawn from the corresponding folder."""
    if nb_samples is not None:
        def sampler(x): return random.sample(x, nb_samples)
    else:
        def sampler(x): return x
images = [(i, os.path.join(path, image))
for i, path in zip(labels, paths)
for image in sampler(os.listdir(path))]
if shuffle:
random.shuffle(images)
return images
def optimistic_restore(session, save_file, v_prefix=None):
reader = tf.train.NewCheckpointReader(save_file)
saved_shapes = reader.get_variable_to_shape_map()
var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES) if var.name.split(':')[0] in saved_shapes])
restore_vars = []
with tf.variable_scope('', reuse=True):
for var_name, saved_var_name in var_names:
try:
curr_var = tf.get_variable(saved_var_name)
except Exception as e:
print(e)
continue
var_shape = curr_var.get_shape().as_list()
if var_shape == saved_shapes[saved_var_name]:
restore_vars.append(curr_var)
else:
print(var_name)
print(var_shape, saved_shapes[saved_var_name])
saver = tf.train.Saver(restore_vars)
saver.restore(session, save_file)
def optimistic_remap_restore(session, save_file, v_prefix):
reader = tf.train.NewCheckpointReader(save_file)
saved_shapes = reader.get_variable_to_shape_map()
vars_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES,
scope='context_{}'.format(v_prefix))
var_names = sorted([(var.name.split(':')[0], var) for var in vars_list if (
(var.name.split(':')[0]).replace('context_{}'.format(v_prefix), 'context_0') in saved_shapes)])
restore_vars = []
v_map = {}
with tf.variable_scope('', reuse=True):
for saved_var_name, curr_var in var_names:
var_shape = curr_var.get_shape().as_list()
saved_var_name = saved_var_name.replace(
'context_{}'.format(v_prefix), 'context_0')
if var_shape == saved_shapes[saved_var_name]:
v_map[saved_var_name] = curr_var
else:
print(saved_var_name)
print(var_shape, saved_shapes[saved_var_name])
saver = tf.train.Saver(v_map)
saver.restore(session, save_file)
def remap_restore(session, save_file, i):
reader = tf.train.NewCheckpointReader(save_file)
saved_shapes = reader.get_variable_to_shape_map()
var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables()
if var.name.split(':')[0] in saved_shapes])
restore_vars = []
with tf.variable_scope('', reuse=True):
for var_name, saved_var_name in var_names:
try:
curr_var = tf.get_variable(saved_var_name)
except Exception as e:
print(e)
continue
var_shape = curr_var.get_shape().as_list()
if var_shape == saved_shapes[saved_var_name]:
restore_vars.append(curr_var)
print(restore_vars)
saver = tf.train.Saver(restore_vars)
saver.restore(session, save_file)
# Network weight initializers
def init_conv_weight(
weights,
scope,
k,
c_in,
c_out,
spec_norm=True,
zero=False,
scale=1.0,
classes=1):
if spec_norm:
spec_norm = FLAGS.spec_norm
conv_weights = {}
with tf.variable_scope(scope):
if zero:
conv_weights['c'] = get_weight(
'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True)
else:
conv_weights['c'] = get_weight(
'c', [k, k, c_in, c_out], spec_norm=spec_norm)
conv_weights['b'] = tf.get_variable(
shape=[c_out], name='b', initializer=tf.initializers.zeros())
if classes != 1:
conv_weights['g'] = tf.get_variable(
shape=[
classes,
c_out],
name='g',
initializer=tf.initializers.ones())
conv_weights['gb'] = tf.get_variable(
shape=[
classes,
c_in],
name='gb',
initializer=tf.initializers.zeros())
else:
conv_weights['g'] = tf.get_variable(
shape=[c_out], name='g', initializer=tf.initializers.ones())
conv_weights['gb'] = tf.get_variable(
shape=[c_in], name='gb', initializer=tf.initializers.zeros())
conv_weights['cb'] = tf.get_variable(
shape=[c_in], name='cb', initializer=tf.initializers.zeros())
weights[scope] = conv_weights
def init_convt_weight(
weights,
scope,
k,
c_in,
c_out,
spec_norm=True,
zero=False,
scale=1.0,
classes=1):
if spec_norm:
spec_norm = FLAGS.spec_norm
conv_weights = {}
with tf.variable_scope(scope):
if zero:
conv_weights['c'] = get_weight(
'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True)
else:
conv_weights['c'] = get_weight(
'c', [k, k, c_in, c_out], spec_norm=spec_norm)
conv_weights['b'] = tf.get_variable(
shape=[c_in], name='b', initializer=tf.initializers.zeros())
if classes != 1:
conv_weights['g'] = tf.get_variable(
shape=[
classes,
c_in],
name='g',
initializer=tf.initializers.ones())
conv_weights['gb'] = tf.get_variable(
shape=[
classes,
c_out],
name='gb',
initializer=tf.initializers.zeros())
else:
conv_weights['g'] = tf.get_variable(
shape=[c_in], name='g', initializer=tf.initializers.ones())
conv_weights['gb'] = tf.get_variable(
shape=[c_out], name='gb', initializer=tf.initializers.zeros())
conv_weights['cb'] = tf.get_variable(
shape=[c_in], name='cb', initializer=tf.initializers.zeros())
weights[scope] = conv_weights
def init_attention_weight(
weights,
scope,
c_in,
k,
trainable_gamma=True,
spec_norm=True):
if spec_norm:
spec_norm = FLAGS.spec_norm
atten_weights = {}
with tf.variable_scope(scope):
atten_weights['q'] = get_weight(
'atten_q', [1, 1, c_in, k], spec_norm=spec_norm)
atten_weights['q_b'] = tf.get_variable(
shape=[k], name='atten_q_b1', initializer=tf.initializers.zeros())
atten_weights['k'] = get_weight(
'atten_k', [1, 1, c_in, k], spec_norm=spec_norm)
atten_weights['k_b'] = tf.get_variable(
shape=[k], name='atten_k_b1', initializer=tf.initializers.zeros())
atten_weights['v'] = get_weight(
'atten_v', [1, 1, c_in, c_in], spec_norm=spec_norm)
atten_weights['v_b'] = tf.get_variable(
shape=[c_in], name='atten_v_b1', initializer=tf.initializers.zeros())
atten_weights['gamma'] = tf.get_variable(
shape=[1], name='gamma', initializer=tf.initializers.zeros())
weights[scope] = atten_weights
def init_fc_weight(weights, scope, c_in, c_out, spec_norm=True):
fc_weights = {}
if spec_norm:
spec_norm = FLAGS.spec_norm
with tf.variable_scope(scope):
fc_weights['w'] = get_weight(
'w', [c_in, c_out], spec_norm=spec_norm, fc=True)
fc_weights['b'] = tf.get_variable(
shape=[c_out], name='b', initializer=tf.initializers.zeros())
weights[scope] = fc_weights
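# Illustrative sketch of how the init_* helpers populate a shared weights
# dictionary keyed by scope. spec_norm=False keeps the example independent of
# the externally defined FLAGS.spec_norm flag.
def _example_init_weights():
    weights = {}
    init_conv_weight(weights, 'conv1', 3, 3, 64, spec_norm=False)
    init_attention_weight(weights, 'atten1', 64, 8, spec_norm=False)
    init_fc_weight(weights, 'fc1', 64, 10, spec_norm=False)
    return weights  # {'conv1': {...}, 'atten1': {...}, 'fc1': {...}}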
def init_res_weight(
weights,
scope,
k,
c_in,
c_out,
hidden_dim=None,
spec_norm=True,
res_scale=1.0,
classes=1):
if not hidden_dim:
hidden_dim = c_in
if spec_norm:
spec_norm = FLAGS.spec_norm
init_conv_weight(
weights,
scope +
'_res_c1',
k,
c_in,
c_out,
spec_norm=spec_norm,
scale=res_scale,
classes=classes)
init_conv_weight(
weights,
scope + '_res_c2',
k,
c_out,
c_out,
spec_norm=spec_norm,
zero=True,
scale=res_scale,
classes=classes)
if c_in != c_out:
init_conv_weight(
weights,
scope +
'_res_adaptive',
k,
c_in,
c_out,
spec_norm=spec_norm,
scale=res_scale,
classes=classes)
# Network forward helpers
def smart_conv_block(inp, weights, reuse, scope, use_stride=True, **kwargs):
weights = weights[scope]
return conv_block(
inp,
weights['c'],
weights['b'],
reuse,
scope,
scale=weights['g'],
bias=weights['gb'],
class_bias=weights['cb'],
use_stride=use_stride,
**kwargs)
def smart_convt_block(
inp,
weights,
reuse,
scope,
output_dim,
upsample=True,
label=None):
weights = weights[scope]
cweight = weights['c']
bweight = weights['b']
scale = weights['g']
bias = weights['gb']
class_bias = weights['cb']
if upsample:
stride = [1, 2, 2, 1]
else:
stride = [1, 1, 1, 1]
if label is not None:
bias_batch = tf.matmul(label, bias)
batch = tf.shape(bias_batch)[0]
dim = tf.shape(bias_batch)[1]
bias = tf.reshape(bias_batch, (batch, 1, 1, dim))
inp = inp + bias
shape = cweight.get_shape()
conv_output = tf.nn.conv2d_transpose(inp,
cweight,
[tf.shape(inp)[0],
output_dim,
output_dim,
cweight.get_shape().as_list()[-2]],
stride,
'SAME')
if label is not None:
scale_batch = tf.matmul(label, scale) + class_bias
batch = tf.shape(scale_batch)[0]
dim = tf.shape(scale_batch)[1]
scale = tf.reshape(scale_batch, (batch, 1, 1, dim))
conv_output = conv_output * scale
conv_output = tf.nn.leaky_relu(conv_output)
return conv_output
def smart_res_block(
inp,
weights,
reuse,
scope,
downsample=True,
adaptive=True,
stop_batch=False,
upsample=False,
label=None,
act=tf.nn.leaky_relu,
dropout=False,
train=False,
**kwargs):
gn1 = weights[scope + '_res_c1']
gn2 = weights[scope + '_res_c2']
c1 = smart_conv_block(
inp,
weights,
reuse,
scope + '_res_c1',
use_stride=False,
activation=None,
extra_bias=True,
label=label,
**kwargs)
if dropout:
c1 = tf.layers.dropout(c1, rate=0.5, training=train)
c1 = act(c1)
c2 = smart_conv_block(
c1,
weights,
reuse,
scope + '_res_c2',
use_stride=False,
activation=None,
use_scale=True,
extra_bias=True,
label=label,
**kwargs)
if adaptive:
c_bypass = smart_conv_block(
inp,
weights,
reuse,
scope +
'_res_adaptive',
use_stride=False,
activation=None,
**kwargs)
else:
c_bypass = inp
res = c2 + c_bypass
if upsample:
res_shape = tf.shape(res)
res_shape_list = res.get_shape()
res = tf.image.resize_nearest_neighbor(
res, [2 * res_shape_list[1], 2 * res_shape_list[2]])
elif downsample:
res = tf.nn.avg_pool(res, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
res = act(res)
return res
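# Illustrative sketch (assumed shapes and scope): a downsampling residual
# block whose weights were registered with init_res_weight.
def _example_res_block():
    weights = {}
    init_res_weight(weights, 'res1', 3, 64, 128, spec_norm=False)
    inp = tf.placeholder(tf.float32, [None, 32, 32, 64])
    return smart_res_block(inp, weights, reuse=False, scope='res1',
                           downsample=True, adaptive=True)  # [None, 16, 16, 128]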
def smart_res_block_optim(inp, weights, reuse, scope, **kwargs):
c1 = smart_conv_block(
inp,
weights,
reuse,
scope + '_res_c1',
use_stride=False,
activation=None,
**kwargs)
c1 = tf.nn.leaky_relu(c1)
c2 = smart_conv_block(
c1,
weights,
reuse,
scope + '_res_c2',
use_stride=False,
activation=None,
**kwargs)
inp = tf.nn.avg_pool(inp, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
c_bypass = smart_conv_block(
inp,
weights,
reuse,
scope +
'_res_adaptive',
use_stride=False,
activation=None,
**kwargs)
c2 = tf.nn.avg_pool(c2, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
    res = c2 + c_bypass
    return res
def groupsort(k=4):
    def sortact(inp):
        # Sort activations within contiguous groups of size k (GroupSort).
        old_shape = tf.shape(inp)
        inp = sort(tf.reshape(inp, (-1, k)))
        inp = tf.reshape(inp, old_shape)
        return inp
    return sortact
def smart_atten_block(inp, weights, reuse, scope, **kwargs):
w = weights[scope]
return attention(
inp,
w['q'],
w['q_b'],
w['k'],
w['k_b'],
w['v'],
w['v_b'],
w['gamma'],
reuse,
scope,
**kwargs)
def smart_fc_block(inp, weights, reuse, scope, use_bias=True):
weights = weights[scope]
output = tf.matmul(inp, weights['w'])
if use_bias:
output = output + weights['b']
return output
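# Illustrative forward-pass sketch combining the smart_* helpers; it assumes
# 'conv1' and 'fc1' were built as in _example_init_weights above and that the
# input is an nhwc image batch with 3 channels.
def _example_forward(weights, inp):
    h = smart_conv_block(inp, weights, reuse=False, scope='conv1',
                         use_stride=False)
    h = tf.reduce_mean(h, axis=[1, 2])  # global average pooling
    return smart_fc_block(h, weights, reuse=False, scope='fc1')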
# Network helpers
def conv_block(
inp,
cweight,
bweight,
reuse,
scope,
use_stride=True,
activation=tf.nn.leaky_relu,
pn=False,
bn=False,
gn=False,
ln=False,
scale=None,
bias=None,
class_bias=None,
use_bias=False,
downsample=False,
stop_batch=False,
use_scale=False,
extra_bias=False,
average=False,
label=None):
""" Perform, conv, batch norm, nonlinearity, and max pool """
stride, no_stride = [1, 2, 2, 1], [1, 1, 1, 1]
_, h, w, _ = inp.get_shape()
if FLAGS.downsample:
stride = no_stride
if not use_bias:
bweight = 0
if extra_bias:
if label is not None:
if len(bias.get_shape()) == 1:
bias = tf.reshape(bias, (1, -1))
bias_batch = tf.matmul(label, bias)
batch = tf.shape(bias_batch)[0]
dim = tf.shape(bias_batch)[1]
bias = tf.reshape(bias_batch, (batch, 1, 1, dim))
inp = inp + bias
if not use_stride:
conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME')
else:
conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME')
if use_scale:
if label is not None:
if len(scale.get_shape()) == 1:
scale = tf.reshape(scale, (1, -1))
scale_batch = tf.matmul(label, scale) + class_bias
batch = tf.shape(scale_batch)[0]
dim = tf.shape(scale_batch)[1]
scale = tf.reshape(scale_batch, (batch, 1, 1, dim))
conv_output = conv_output * scale
if use_bias:
conv_output = conv_output + bweight
if activation is not None:
conv_output = activation(conv_output)
if bn:
conv_output = batch_norm(conv_output, scale, bias)
if pn:
conv_output = pixel_norm(conv_output)
if gn:
conv_output = group_norm(
conv_output, scale, bias, stop_batch=stop_batch)
if ln:
conv_output = layer_norm(conv_output, scale, bias)
if FLAGS.downsample and use_stride:
conv_output = tf.layers.average_pooling2d(conv_output, (2, 2), 2)
return conv_output
def conv_block_1d(
inp,
cweight,
bweight,
reuse,
scope,
activation=tf.nn.leaky_relu):
""" Perform, conv, batch norm, nonlinearity, and max pool """
stride = 1
conv_output = tf.nn.conv1d(inp, cweight, stride, 'SAME') + bweight
if activation is not None:
conv_output = activation(conv_output)
return conv_output
def conv_block_3d(
inp,
cweight,
bweight,
reuse,
scope,
use_stride=True,
activation=tf.nn.leaky_relu,
pn=False,
bn=False,
gn=False,
ln=False,
scale=None,
bias=None,
use_bias=False):
""" Perform, conv, batch norm, nonlinearity, and max pool """
stride, no_stride = [1, 1, 2, 2, 1], [1, 1, 1, 1, 1]
_, d, h, w, _ = inp.get_shape()
if not use_bias:
bweight = 0
if not use_stride:
conv_output = tf.nn.conv3d(inp, cweight, no_stride, 'SAME') + bweight
else:
conv_output = tf.nn.conv3d(inp, cweight, stride, 'SAME') + bweight
if activation is not None:
conv_output = activation(conv_output, alpha=0.1)
if bn:
conv_output = batch_norm(conv_output, scale, bias)
if pn:
conv_output = pixel_norm(conv_output)
if gn:
conv_output = group_norm(conv_output, scale, bias)
if ln:
conv_output = layer_norm(conv_output, scale, bias)
    if FLAGS.downsample and use_stride:
        # Pool only the spatial (h, w) axes of the 5-D ndhwc tensor.
        conv_output = tf.layers.average_pooling3d(
            conv_output, (1, 2, 2), (1, 2, 2))
return conv_output
def group_norm(inp, scale, bias, g=32, eps=1e-6, stop_batch=False):
"""Applies group normalization assuming nhwc format"""
n, h, w, c = inp.shape
inp = tf.reshape(inp, (tf.shape(inp)[0], h, w, c // g, g))
mean, var = tf.nn.moments(inp, [1, 2, 4], keep_dims=True)
gain = tf.rsqrt(var + eps)
# if stop_batch:
# gain = tf.stop_gradient(gain)
output = gain * (inp - mean)
output = tf.reshape(output, (tf.shape(inp)[0], h, w, c))
if scale is not None:
output = output * scale
if bias is not None:
output = output + bias
return output
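# Reference sketch in NumPy (not part of the original module) documenting the
# grouping used above: each group of g consecutive channels is normalized
# over h, w and the group dimension.
def _group_norm_reference(x, g=32, eps=1e-6):
    n, h, w, c = x.shape
    xg = x.reshape(n, h, w, c // g, g)
    mean = xg.mean(axis=(1, 2, 4), keepdims=True)
    var = xg.var(axis=(1, 2, 4), keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)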
def layer_norm(inp, scale, bias, eps=1e-6):
"""Applies group normalization assuming nhwc format"""
n, h, w, c = inp.shape
mean, var = tf.nn.moments(inp, [1, 2, 3], keep_dims=True)
gain = tf.rsqrt(var + eps)
output = gain * (inp - mean)
if scale is not None:
output = output * scale
if bias is not None:
output = output + bias
return output
def conv_cond_concat(x, y):
"""Concatenate conditioning vector on feature map axis."""
x_shapes = tf.shape(x)
y_shapes = tf.shape(y)
return tf.concat(
[x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]]) / 10.], 3)
def attention(
inp,
q,
q_b,
k,
k_b,
v,
v_b,
gamma,
reuse,
scope,
stop_at_grad=False,
seperate=False,
scale=False,
train=False,
dropout=0.0):
conv_q = conv_block(
inp,
q,
q_b,
reuse=reuse,
scope=scope,
use_stride=False,
activation=None,
use_bias=True,
pn=False,
bn=False,
gn=False)
conv_k = conv_block(
inp,
k,
k_b,
reuse=reuse,
scope=scope,
use_stride=False,
activation=None,
use_bias=True,
pn=False,
bn=False,
gn=False)
conv_v = conv_block(
inp,
v,
v_b,
reuse=reuse,
scope=scope,
use_stride=False,
pn=False,
bn=False,
gn=False)
c_num = float(conv_q.get_shape().as_list()[-1])
s = tf.matmul(hw_flatten(conv_q), hw_flatten(conv_k), transpose_b=True)
if scale:
s = s / (c_num) ** 0.5
if train:
s = tf.nn.dropout(s, 0.9)
beta = tf.nn.softmax(s, axis=-1)
o = tf.matmul(beta, hw_flatten(conv_v))
o = tf.reshape(o, shape=tf.shape(inp))
inp = inp + gamma * o
if not seperate:
return inp
else:
return gamma * o
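# Illustrative sketch: applying the self-attention block through
# smart_atten_block with weights from init_attention_weight ('atten1' and the
# input shape are assumptions). Since gamma is initialized to zero, the block
# initially acts as an identity plus a learnable attention residual.
def _example_attention(weights, inp):
    # inp: [batch, h, w, 64]
    return smart_atten_block(inp, weights, reuse=False, scope='atten1')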
def attention_2d(
inp,
q,
q_b,
k,
k_b,
v,
v_b,
reuse,
scope,
stop_at_grad=False,
seperate=False,
scale=False):
inp_shape = tf.shape(inp)
inp_compact = tf.reshape(
inp,
(inp_shape[0] *
FLAGS.input_objects *
inp_shape[1],
inp.shape[3]))
f_q = tf.matmul(inp_compact, q) + q_b
f_k = tf.matmul(inp_compact, k) + k_b
f_v = tf.nn.leaky_relu(tf.matmul(inp_compact, v) + v_b)
f_q = tf.reshape(f_q,
(inp_shape[0],
inp_shape[1],
inp_shape[2],
tf.shape(f_q)[-1]))
f_k = tf.reshape(f_k,
(inp_shape[0],
inp_shape[1],
inp_shape[2],
tf.shape(f_k)[-1]))
f_v = tf.reshape(
f_v,
(inp_shape[0],
inp_shape[1],
inp_shape[2],
inp_shape[3]))
s = tf.matmul(f_k, f_q, transpose_b=True)
c_num = (32**0.5)
if scale:
s = s / c_num
beta = tf.nn.softmax(s, axis=-1)
o = tf.reshape(tf.matmul(beta, f_v), inp_shape) + inp
return o
def hw_flatten(x):
    shape = tf.shape(x)
    return tf.reshape(x, [shape[0], -1, shape[-1]])
def batch_norm(inp, scale, bias, eps=0.01):
mean, var = tf.nn.moments(inp, [0])
output = tf.nn.batch_normalization(inp, mean, var, bias, scale, eps)
return output
def normalize(inp, activation, reuse, scope):
if FLAGS.norm == 'batch_norm':
return tf_layers.batch_norm(
inp,
activation_fn=activation,
reuse=reuse,
scope=scope)
elif FLAGS.norm == 'layer_norm':
return tf_layers.layer_norm(
inp,
activation_fn=activation,
reuse=reuse,
scope=scope)
elif FLAGS.norm == 'None':
if activation is not None:
return activation(inp)
else:
return inp
# Loss functions
def mse(pred, label):
pred = tf.reshape(pred, [-1])
label = tf.reshape(label, [-1])
return tf.reduce_mean(tf.square(pred - label))
NO_OPS = 'NO_OPS'
def _l2normalize(v, eps=1e-12):
return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
def spectral_normed_weight(w, name, lower_bound=False, iteration=1, fc=False):
    # The power-iteration count is taken from FLAGS.spec_iter, which
    # supersedes both the iteration argument and the fc-specific default.
    iteration = FLAGS.spec_iter
    sigma_new = FLAGS.spec_norm_val
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])
u = tf.get_variable(name + "_u",
[1,
w_shape[-1]],
initializer=tf.random_normal_initializer(),
trainable=False)
u_hat = u
v_hat = None
    # Power iteration; a single iteration is usually enough.
    for i in range(iteration):
v_ = tf.matmul(u_hat, tf.transpose(w))
v_hat = tf.nn.l2_normalize(v_)
u_ = tf.matmul(v_hat, w)
u_hat = tf.nn.l2_normalize(u_)
u_hat = tf.stop_gradient(u_hat)
v_hat = tf.stop_gradient(v_hat)
sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
if FLAGS.spec_eval:
dep = []
else:
dep = [u.assign(u_hat)]
with tf.control_dependencies(dep):
if lower_bound:
sigma = sigma + 1e-6
w_norm = w / sigma * tf.minimum(sigma, 1) * sigma_new
else:
w_norm = w / sigma * sigma_new
w_norm = tf.reshape(w_norm, w_shape)
return w_norm
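# Sanity-check sketch (not part of the original module): the power iteration
# above estimates the largest singular value of the reshaped kernel; the same
# quantity can be computed directly in NumPy for comparison.
def _reference_spectral_norm(w_value):
    w2d = w_value.reshape(-1, w_value.shape[-1])
    return np.linalg.svd(w2d, compute_uv=False)[0]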
def average_gradients(tower_grads):
"""Calculate the average gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list
is over individual gradients. The inner list is over the gradient
calculation for each tower.
Returns:
List of pairs of (gradient, variable) where the gradient has been averaged
across all towers.
"""
average_grads = []
for grad_and_vars in zip(*tower_grads):
# Note that each grad_and_vars looks like the following:
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
grads = []
for g, v in grad_and_vars:
if g is not None:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over
# below.
grads.append(expanded_g)
else:
print(g, v)
# Average over the 'tower' dimension.
grad = tf.concat(axis=0, values=grads)
grad = tf.reduce_mean(grad, 0)
# Keep in mind that the Variables are redundant because they are shared
# across towers. So .. we will just return the first tower's pointer to
# the Variable.
v = grad_and_vars[0][1]
grad_and_var = (grad, v)
average_grads.append(grad_and_var)
return average_grads
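# Illustrative usage sketch (optimizer and per-tower losses are assumptions):
# gradients are computed per tower, averaged, and applied once.
def _example_average_gradients(optimizer, tower_losses):
    tower_grads = [optimizer.compute_gradients(loss) for loss in tower_losses]
    return optimizer.apply_gradients(average_gradients(tower_grads))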