""" Utility functions. """ import numpy as np import os import random import tensorflow as tf import warnings from tensorflow.contrib.layers.python import layers as tf_layers from tensorflow.python.platform import flags from tensorflow.contrib.framework import sort FLAGS = flags.FLAGS flags.DEFINE_integer('spec_iter', 1, 'Number of iterations to normalize spectrum of matrix') flags.DEFINE_float('spec_norm_val', 1.0, 'Desired norm of matrices') flags.DEFINE_bool('downsample', False, 'Wheter to do average pool downsampling') flags.DEFINE_bool('spec_eval', False, 'Set to true to prevent spectral updates') def get_median(v): v = tf.reshape(v, [-1]) m = tf.shape(v)[0] // 2 return tf.nn.top_k(v, m)[m - 1] def set_seed(seed): import torch import numpy import random torch.manual_seed(seed) numpy.random.seed(seed) random.seed(seed) tf.set_random_seed(seed) def swish(inp): return inp * tf.nn.sigmoid(inp) class ReplayBuffer(object): def __init__(self, size): """Create Replay buffer. Parameters ---------- size: int Max number of transitions to store in the buffer. When the buffer overflows the old memories are dropped. """ self._storage = [] self._maxsize = size self._next_idx = 0 def __len__(self): return len(self._storage) def add(self, ims): batch_size = ims.shape[0] if self._next_idx >= len(self._storage): self._storage.extend(list(ims)) else: if batch_size + self._next_idx < self._maxsize: self._storage[self._next_idx:self._next_idx + batch_size] = list(ims) else: split_idx = self._maxsize - self._next_idx self._storage[self._next_idx:] = list(ims)[:split_idx] self._storage[:batch_size - split_idx] = list(ims)[split_idx:] self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize def _encode_sample(self, idxes): ims = [] for i in idxes: ims.append(self._storage[i]) return np.array(ims) def sample(self, batch_size): """Sample a batch of experiences. Parameters ---------- batch_size: int How many transitions to sample. Returns ------- obs_batch: np.array batch of observations act_batch: np.array batch of actions executed given obs_batch rew_batch: np.array rewards received as results of executing act_batch next_obs_batch: np.array next set of observations seen after executing act_batch done_mask: np.array done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode and 0 otherwise. 
""" idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] return self._encode_sample(idxes) def get_weight( name, shape, gain=np.sqrt(2), use_wscale=False, fan_in=None, spec_norm=False, zero=False, fc=False): if fan_in is None: fan_in = np.prod(shape[:-1]) std = gain / np.sqrt(fan_in) # He init if use_wscale: wscale = tf.constant(np.float32(std), name=name + 'wscale') var = tf.get_variable( name + 'weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale elif spec_norm: if zero: var = tf.get_variable( shape=shape, name=name + 'weight', initializer=tf.initializers.random_normal( stddev=1e-10)) var = spectral_normed_weight(var, name, lower_bound=True, fc=fc) else: var = tf.get_variable( name + 'weight', shape=shape, initializer=tf.initializers.random_normal()) var = spectral_normed_weight(var, name, fc=fc) else: if zero: var = tf.get_variable( name + 'weight', shape=shape, initializer=tf.initializers.zero()) else: var = tf.get_variable( name + 'weight', shape=shape, initializer=tf.contrib.layers.xavier_initializer( dtype=tf.float32)) return var def pixel_norm(x, epsilon=1e-8): with tf.variable_scope('PixelNorm'): return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) + epsilon) # helper def get_images(paths, labels, nb_samples=None, shuffle=True): if nb_samples is not None: def sampler(x): return random.sample(x, nb_samples) else: def sampler(x): return x images = [(i, os.path.join(path, image)) for i, path in zip(labels, paths) for image in sampler(os.listdir(path))] if shuffle: random.shuffle(images) return images def optimistic_restore(session, save_file, v_prefix=None): reader = tf.train.NewCheckpointReader(save_file) saved_shapes = reader.get_variable_to_shape_map() var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES) if var.name.split(':')[0] in saved_shapes]) restore_vars = [] with tf.variable_scope('', reuse=True): for var_name, saved_var_name in var_names: try: curr_var = tf.get_variable(saved_var_name) except Exception as e: print(e) continue var_shape = curr_var.get_shape().as_list() if var_shape == saved_shapes[saved_var_name]: restore_vars.append(curr_var) else: print(var_name) print(var_shape, saved_shapes[saved_var_name]) saver = tf.train.Saver(restore_vars) saver.restore(session, save_file) def optimistic_remap_restore(session, save_file, v_prefix): reader = tf.train.NewCheckpointReader(save_file) saved_shapes = reader.get_variable_to_shape_map() vars_list = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(v_prefix)) var_names = sorted([(var.name.split(':')[0], var) for var in vars_list if ( (var.name.split(':')[0]).replace('context_{}'.format(v_prefix), 'context_0') in saved_shapes)]) restore_vars = [] v_map = {} with tf.variable_scope('', reuse=True): for saved_var_name, curr_var in var_names: var_shape = curr_var.get_shape().as_list() saved_var_name = saved_var_name.replace( 'context_{}'.format(v_prefix), 'context_0') if var_shape == saved_shapes[saved_var_name]: v_map[saved_var_name] = curr_var else: print(saved_var_name) print(var_shape, saved_shapes[saved_var_name]) saver = tf.train.Saver(v_map) saver.restore(session, save_file) def remap_restore(session, save_file, i): reader = tf.train.NewCheckpointReader(save_file) saved_shapes = reader.get_variable_to_shape_map() var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables() if var.name.split(':')[0] in saved_shapes]) restore_vars = [] 
# Network weight initializers
def init_conv_weight(
        weights,
        scope,
        k,
        c_in,
        c_out,
        spec_norm=True,
        zero=False,
        scale=1.0,
        classes=1):

    if spec_norm:
        spec_norm = FLAGS.spec_norm

    conv_weights = {}

    with tf.variable_scope(scope):
        if zero:
            conv_weights['c'] = get_weight(
                'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True)
        else:
            conv_weights['c'] = get_weight(
                'c', [k, k, c_in, c_out], spec_norm=spec_norm)

        conv_weights['b'] = tf.get_variable(
            shape=[c_out], name='b', initializer=tf.initializers.zeros())

        if classes != 1:
            conv_weights['g'] = tf.get_variable(
                shape=[classes, c_out], name='g',
                initializer=tf.initializers.ones())
            conv_weights['gb'] = tf.get_variable(
                shape=[classes, c_in], name='gb',
                initializer=tf.initializers.zeros())
        else:
            conv_weights['g'] = tf.get_variable(
                shape=[c_out], name='g', initializer=tf.initializers.ones())
            conv_weights['gb'] = tf.get_variable(
                shape=[c_in], name='gb', initializer=tf.initializers.zeros())

        conv_weights['cb'] = tf.get_variable(
            shape=[c_in], name='cb', initializer=tf.initializers.zeros())

    weights[scope] = conv_weights


def init_convt_weight(
        weights,
        scope,
        k,
        c_in,
        c_out,
        spec_norm=True,
        zero=False,
        scale=1.0,
        classes=1):

    if spec_norm:
        spec_norm = FLAGS.spec_norm

    conv_weights = {}

    with tf.variable_scope(scope):
        if zero:
            conv_weights['c'] = get_weight(
                'c', [k, k, c_in, c_out], spec_norm=spec_norm, zero=True)
        else:
            conv_weights['c'] = get_weight(
                'c', [k, k, c_in, c_out], spec_norm=spec_norm)

        conv_weights['b'] = tf.get_variable(
            shape=[c_in], name='b', initializer=tf.initializers.zeros())

        if classes != 1:
            conv_weights['g'] = tf.get_variable(
                shape=[classes, c_in], name='g',
                initializer=tf.initializers.ones())
            conv_weights['gb'] = tf.get_variable(
                shape=[classes, c_out], name='gb',
                initializer=tf.initializers.zeros())
        else:
            conv_weights['g'] = tf.get_variable(
                shape=[c_in], name='g', initializer=tf.initializers.ones())
            conv_weights['gb'] = tf.get_variable(
                shape=[c_out], name='gb', initializer=tf.initializers.zeros())

        conv_weights['cb'] = tf.get_variable(
            shape=[c_in], name='cb', initializer=tf.initializers.zeros())

    weights[scope] = conv_weights


def init_attention_weight(
        weights,
        scope,
        c_in,
        k,
        trainable_gamma=True,
        spec_norm=True):

    if spec_norm:
        spec_norm = FLAGS.spec_norm

    atten_weights = {}

    with tf.variable_scope(scope):
        atten_weights['q'] = get_weight(
            'atten_q', [1, 1, c_in, k], spec_norm=spec_norm)
        atten_weights['q_b'] = tf.get_variable(
            shape=[k], name='atten_q_b1', initializer=tf.initializers.zeros())
        atten_weights['k'] = get_weight(
            'atten_k', [1, 1, c_in, k], spec_norm=spec_norm)
        atten_weights['k_b'] = tf.get_variable(
            shape=[k], name='atten_k_b1', initializer=tf.initializers.zeros())
        atten_weights['v'] = get_weight(
            'atten_v', [1, 1, c_in, c_in], spec_norm=spec_norm)
        atten_weights['v_b'] = tf.get_variable(
            shape=[c_in], name='atten_v_b1', initializer=tf.initializers.zeros())
        atten_weights['gamma'] = tf.get_variable(
            shape=[1], name='gamma', initializer=tf.initializers.zeros())

    weights[scope] = atten_weights


def init_fc_weight(weights, scope, c_in, c_out, spec_norm=True):
    fc_weights = {}

    if spec_norm:
        spec_norm = FLAGS.spec_norm

    with tf.variable_scope(scope):
        fc_weights['w'] = get_weight(
            'w', [c_in, c_out], spec_norm=spec_norm, fc=True)
        fc_weights['b'] = tf.get_variable(
            shape=[c_out], name='b', initializer=tf.initializers.zeros())

    weights[scope] = fc_weights
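
# Illustrative sketch (dimensions are assumed): the init_* helpers populate a
# plain dict keyed by scope name, and the matching smart_*_block helpers
# defined below read the same dict at forward time. Spectral normalization is
# disabled here so the sketch does not depend on the externally defined
# spec_norm flag.
def _init_weights_example():
    weights = {}
    init_conv_weight(weights, 'conv1', 3, 3, 64, spec_norm=False)
    init_fc_weight(weights, 'fc5', 64, 1, spec_norm=False)
    return weights  # {'conv1': {'c', 'b', 'g', 'gb', 'cb'}, 'fc5': {'w', 'b'}}
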
def init_res_weight(
        weights,
        scope,
        k,
        c_in,
        c_out,
        hidden_dim=None,
        spec_norm=True,
        res_scale=1.0,
        classes=1):

    if not hidden_dim:
        hidden_dim = c_in

    if spec_norm:
        spec_norm = FLAGS.spec_norm

    init_conv_weight(
        weights, scope + '_res_c1', k, c_in, c_out,
        spec_norm=spec_norm, scale=res_scale, classes=classes)
    init_conv_weight(
        weights, scope + '_res_c2', k, c_out, c_out,
        spec_norm=spec_norm, zero=True, scale=res_scale, classes=classes)

    if c_in != c_out:
        init_conv_weight(
            weights, scope + '_res_adaptive', k, c_in, c_out,
            spec_norm=spec_norm, scale=res_scale, classes=classes)


# Network forward helpers
def smart_conv_block(inp, weights, reuse, scope, use_stride=True, **kwargs):
    weights = weights[scope]
    return conv_block(
        inp, weights['c'], weights['b'], reuse, scope,
        scale=weights['g'], bias=weights['gb'], class_bias=weights['cb'],
        use_stride=use_stride, **kwargs)


def smart_convt_block(
        inp, weights, reuse, scope, output_dim, upsample=True, label=None):
    weights = weights[scope]
    cweight = weights['c']
    bweight = weights['b']
    scale = weights['g']
    bias = weights['gb']
    class_bias = weights['cb']

    if upsample:
        stride = [1, 2, 2, 1]
    else:
        stride = [1, 1, 1, 1]

    if label is not None:
        bias_batch = tf.matmul(label, bias)
        batch = tf.shape(bias_batch)[0]
        dim = tf.shape(bias_batch)[1]
        bias = tf.reshape(bias_batch, (batch, 1, 1, dim))
        inp = inp + bias

    shape = cweight.get_shape()
    conv_output = tf.nn.conv2d_transpose(
        inp, cweight,
        [tf.shape(inp)[0], output_dim, output_dim,
         cweight.get_shape().as_list()[-2]],
        stride, 'SAME')

    if label is not None:
        scale_batch = tf.matmul(label, scale) + class_bias
        batch = tf.shape(scale_batch)[0]
        dim = tf.shape(scale_batch)[1]
        scale = tf.reshape(scale_batch, (batch, 1, 1, dim))
        conv_output = conv_output * scale

    conv_output = tf.nn.leaky_relu(conv_output)

    return conv_output


def smart_res_block(
        inp,
        weights,
        reuse,
        scope,
        downsample=True,
        adaptive=True,
        stop_batch=False,
        upsample=False,
        label=None,
        act=tf.nn.leaky_relu,
        dropout=False,
        train=False,
        **kwargs):
    gn1 = weights[scope + '_res_c1']
    gn2 = weights[scope + '_res_c2']
    c1 = smart_conv_block(
        inp, weights, reuse, scope + '_res_c1',
        use_stride=False, activation=None, extra_bias=True, label=label,
        **kwargs)

    if dropout:
        c1 = tf.layers.dropout(c1, rate=0.5, training=train)

    c1 = act(c1)
    c2 = smart_conv_block(
        c1, weights, reuse, scope + '_res_c2',
        use_stride=False, activation=None, use_scale=True, extra_bias=True,
        label=label, **kwargs)

    if adaptive:
        c_bypass = smart_conv_block(
            inp, weights, reuse, scope + '_res_adaptive',
            use_stride=False, activation=None, **kwargs)
    else:
        c_bypass = inp

    res = c2 + c_bypass

    if upsample:
        res_shape = tf.shape(res)
        res_shape_list = res.get_shape()
        res = tf.image.resize_nearest_neighbor(
            res, [2 * res_shape_list[1], 2 * res_shape_list[2]])
    elif downsample:
        res = tf.nn.avg_pool(res, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

    res = act(res)

    return res


def smart_res_block_optim(inp, weights, reuse, scope, **kwargs):
    c1 = smart_conv_block(
        inp, weights, reuse, scope + '_res_c1',
        use_stride=False, activation=None, **kwargs)
    c1 = tf.nn.leaky_relu(c1)
    c2 = smart_conv_block(
        c1, weights, reuse, scope + '_res_c2',
        use_stride=False, activation=None, **kwargs)

    inp = tf.nn.avg_pool(inp, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
    c_bypass = smart_conv_block(
        inp, weights, reuse, scope + '_res_adaptive',
        use_stride=False, activation=None, **kwargs)
    c2 = tf.nn.avg_pool(c2, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

    # Sum of the residual branch and the downsampled shortcut.
    res = c2 + c_bypass

    return res
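
# Illustrative sketch (shapes assumed, unconditional case): a residual block is
# built in two steps -- init_res_weight fills the weight dict, then
# smart_res_block runs the forward pass on an NHWC tensor, average-pooling the
# result when downsample=True.
def _res_block_example():
    weights = {}
    init_res_weight(weights, 'res1', 3, 64, 128, spec_norm=False)
    x = tf.placeholder(tf.float32, [None, 32, 32, 64])
    out = smart_res_block(x, weights, reuse=False, scope='res1',
                          downsample=True, adaptive=True)
    return out  # shape (?, 16, 16, 128)
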
def groupsort(k=4):
    def sortact(inp):
        old_shape = tf.shape(inp)
        inp = sort(tf.reshape(inp, (-1, k)))
        inp = tf.reshape(inp, old_shape)
        return inp
    return sortact


def smart_atten_block(inp, weights, reuse, scope, **kwargs):
    w = weights[scope]
    return attention(
        inp, w['q'], w['q_b'], w['k'], w['k_b'], w['v'], w['v_b'], w['gamma'],
        reuse, scope, **kwargs)


def smart_fc_block(inp, weights, reuse, scope, use_bias=True):
    weights = weights[scope]
    output = tf.matmul(inp, weights['w'])

    if use_bias:
        output = output + weights['b']

    return output


# Network helpers
def conv_block(
        inp,
        cweight,
        bweight,
        reuse,
        scope,
        use_stride=True,
        activation=tf.nn.leaky_relu,
        pn=False,
        bn=False,
        gn=False,
        ln=False,
        scale=None,
        bias=None,
        class_bias=None,
        use_bias=False,
        downsample=False,
        stop_batch=False,
        use_scale=False,
        extra_bias=False,
        average=False,
        label=None):
    """ Perform conv, normalization, nonlinearity, and optional downsampling """
    stride, no_stride = [1, 2, 2, 1], [1, 1, 1, 1]
    _, h, w, _ = inp.get_shape()

    if FLAGS.downsample:
        stride = no_stride

    if not use_bias:
        bweight = 0

    if extra_bias:
        if label is not None:
            if len(bias.get_shape()) == 1:
                bias = tf.reshape(bias, (1, -1))
            bias_batch = tf.matmul(label, bias)
            batch = tf.shape(bias_batch)[0]
            dim = tf.shape(bias_batch)[1]
            bias = tf.reshape(bias_batch, (batch, 1, 1, dim))

        inp = inp + bias

    if not use_stride:
        conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME')
    else:
        conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME')

    if use_scale:
        if label is not None:
            if len(scale.get_shape()) == 1:
                scale = tf.reshape(scale, (1, -1))
            scale_batch = tf.matmul(label, scale) + class_bias
            batch = tf.shape(scale_batch)[0]
            dim = tf.shape(scale_batch)[1]
            scale = tf.reshape(scale_batch, (batch, 1, 1, dim))

        conv_output = conv_output * scale

    if use_bias:
        conv_output = conv_output + bweight

    if activation is not None:
        conv_output = activation(conv_output)

    if bn:
        conv_output = batch_norm(conv_output, scale, bias)
    if pn:
        conv_output = pixel_norm(conv_output)
    if gn:
        conv_output = group_norm(
            conv_output, scale, bias, stop_batch=stop_batch)
    if ln:
        conv_output = layer_norm(conv_output, scale, bias)

    if FLAGS.downsample and use_stride:
        conv_output = tf.layers.average_pooling2d(conv_output, (2, 2), 2)

    return conv_output


def conv_block_1d(
        inp, cweight, bweight, reuse, scope, activation=tf.nn.leaky_relu):
    """ Perform conv and nonlinearity """
    stride = 1

    conv_output = tf.nn.conv1d(inp, cweight, stride, 'SAME') + bweight

    if activation is not None:
        conv_output = activation(conv_output)

    return conv_output


def conv_block_3d(
        inp,
        cweight,
        bweight,
        reuse,
        scope,
        use_stride=True,
        activation=tf.nn.leaky_relu,
        pn=False,
        bn=False,
        gn=False,
        ln=False,
        scale=None,
        bias=None,
        use_bias=False):
    """ Perform conv, normalization, nonlinearity, and optional downsampling """
    stride, no_stride = [1, 1, 2, 2, 1], [1, 1, 1, 1, 1]
    _, d, h, w, _ = inp.get_shape()

    if not use_bias:
        bweight = 0

    if not use_stride:
        conv_output = tf.nn.conv3d(inp, cweight, no_stride, 'SAME') + bweight
    else:
        conv_output = tf.nn.conv3d(inp, cweight, stride, 'SAME') + bweight

    if activation is not None:
        conv_output = activation(conv_output, alpha=0.1)

    if bn:
        conv_output = batch_norm(conv_output, scale, bias)
    if pn:
        conv_output = pixel_norm(conv_output)
    if gn:
        conv_output = group_norm(conv_output, scale, bias)
    if ln:
        conv_output = layer_norm(conv_output, scale, bias)

    if FLAGS.downsample and use_stride:
        # 5-D feature maps need the 3-D pooling op; pool spatially only.
        conv_output = tf.layers.average_pooling3d(
            conv_output, (1, 2, 2), (1, 2, 2))

    return conv_output
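
# Illustrative sketch (kernel/bias names and shapes are assumed): conv_block can
# also be called directly with raw tensors rather than through the smart_*
# wrappers. With the default flags this is a strided conv followed by a leaky
# ReLU; when FLAGS.downsample is set the stride is replaced by average pooling.
def _conv_block_example():
    x = tf.placeholder(tf.float32, [None, 32, 32, 3])
    w = tf.get_variable(
        'example_w', [3, 3, 3, 64],
        initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable(
        'example_b', [64], initializer=tf.initializers.zeros())
    return conv_block(x, w, b, reuse=False, scope='example', use_bias=True)
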
def group_norm(inp, scale, bias, g=32, eps=1e-6, stop_batch=False):
    """Applies group normalization assuming nhwc format"""
    n, h, w, c = inp.shape
    inp = tf.reshape(inp, (tf.shape(inp)[0], h, w, c // g, g))

    mean, var = tf.nn.moments(inp, [1, 2, 4], keep_dims=True)
    gain = tf.rsqrt(var + eps)

    # if stop_batch:
    #     gain = tf.stop_gradient(gain)

    output = gain * (inp - mean)
    output = tf.reshape(output, (tf.shape(inp)[0], h, w, c))

    if scale is not None:
        output = output * scale

    if bias is not None:
        output = output + bias

    return output


def layer_norm(inp, scale, bias, eps=1e-6):
    """Applies layer normalization assuming nhwc format"""
    n, h, w, c = inp.shape

    mean, var = tf.nn.moments(inp, [1, 2, 3], keep_dims=True)
    gain = tf.rsqrt(var + eps)
    output = gain * (inp - mean)

    if scale is not None:
        output = output * scale

    if bias is not None:
        output = output + bias

    return output


def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_shapes = tf.shape(x)
    y_shapes = tf.shape(y)
    return tf.concat(
        [x,
         y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]]) / 10.],
        3)


def attention(
        inp,
        q,
        q_b,
        k,
        k_b,
        v,
        v_b,
        gamma,
        reuse,
        scope,
        stop_at_grad=False,
        seperate=False,
        scale=False,
        train=False,
        dropout=0.0):
    conv_q = conv_block(
        inp, q, q_b, reuse=reuse, scope=scope, use_stride=False,
        activation=None, use_bias=True, pn=False, bn=False, gn=False)
    conv_k = conv_block(
        inp, k, k_b, reuse=reuse, scope=scope, use_stride=False,
        activation=None, use_bias=True, pn=False, bn=False, gn=False)
    conv_v = conv_block(
        inp, v, v_b, reuse=reuse, scope=scope, use_stride=False,
        pn=False, bn=False, gn=False)

    c_num = float(conv_q.get_shape().as_list()[-1])
    s = tf.matmul(hw_flatten(conv_q), hw_flatten(conv_k), transpose_b=True)

    if scale:
        s = s / (c_num) ** 0.5

    if train:
        s = tf.nn.dropout(s, 0.9)

    beta = tf.nn.softmax(s, axis=-1)
    o = tf.matmul(beta, hw_flatten(conv_v))
    o = tf.reshape(o, shape=tf.shape(inp))
    inp = inp + gamma * o

    if not seperate:
        return inp
    else:
        return gamma * o


def attention_2d(
        inp, q, q_b, k, k_b, v, v_b, reuse, scope,
        stop_at_grad=False, seperate=False, scale=False):
    inp_shape = tf.shape(inp)
    inp_compact = tf.reshape(
        inp,
        (inp_shape[0] * FLAGS.input_objects * inp_shape[1], inp.shape[3]))
    f_q = tf.matmul(inp_compact, q) + q_b
    f_k = tf.matmul(inp_compact, k) + k_b
    f_v = tf.nn.leaky_relu(tf.matmul(inp_compact, v) + v_b)

    f_q = tf.reshape(
        f_q, (inp_shape[0], inp_shape[1], inp_shape[2], tf.shape(f_q)[-1]))
    f_k = tf.reshape(
        f_k, (inp_shape[0], inp_shape[1], inp_shape[2], tf.shape(f_k)[-1]))
    f_v = tf.reshape(
        f_v, (inp_shape[0], inp_shape[1], inp_shape[2], inp_shape[3]))

    s = tf.matmul(f_k, f_q, transpose_b=True)

    c_num = (32 ** 0.5)

    if scale:
        s = s / c_num

    beta = tf.nn.softmax(s, axis=-1)
    o = tf.reshape(tf.matmul(beta, f_v), inp_shape) + inp

    return o


def hw_flatten(x):
    shape = tf.shape(x)
    return tf.reshape(x, [tf.shape(x)[0], -1, shape[-1]])


def batch_norm(inp, scale, bias, eps=0.01):
    mean, var = tf.nn.moments(inp, [0])
    output = tf.nn.batch_normalization(inp, mean, var, bias, scale, eps)
    return output


def normalize(inp, activation, reuse, scope):
    if FLAGS.norm == 'batch_norm':
        return tf_layers.batch_norm(
            inp, activation_fn=activation, reuse=reuse, scope=scope)
    elif FLAGS.norm == 'layer_norm':
        return tf_layers.layer_norm(
            inp, activation_fn=activation, reuse=reuse, scope=scope)
    elif FLAGS.norm == 'None':
        if activation is not None:
            return activation(inp)
        else:
            return inp
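
# Illustrative sketch (channel and key sizes are assumed): self-attention
# weights are created with init_attention_weight and applied with
# smart_atten_block, which adds the gamma-scaled attention output back onto
# its input.
def _attention_example():
    weights = {}
    init_attention_weight(weights, 'atten1', 64, 16, spec_norm=False)
    x = tf.placeholder(tf.float32, [None, 8, 8, 64])
    return smart_atten_block(x, weights, reuse=False, scope='atten1')
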
# Loss functions
def mse(pred, label):
    pred = tf.reshape(pred, [-1])
    label = tf.reshape(label, [-1])
    return tf.reduce_mean(tf.square(pred - label))


NO_OPS = 'NO_OPS'


def _l2normalize(v, eps=1e-12):
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)


def spectral_normed_weight(w, name, lower_bound=False, iteration=1, fc=False):
    if fc:
        iteration = 2

    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])
    iteration = FLAGS.spec_iter
    sigma_new = FLAGS.spec_norm_val

    u = tf.get_variable(
        name + "_u", [1, w_shape[-1]],
        initializer=tf.random_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        # Power iteration; usually iteration = 1 will be enough.
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    if FLAGS.spec_eval:
        dep = []
    else:
        dep = [u.assign(u_hat)]

    with tf.control_dependencies(dep):
        if lower_bound:
            sigma = sigma + 1e-6
            w_norm = w / sigma * tf.minimum(sigma, 1) * sigma_new
        else:
            w_norm = w / sigma * sigma_new

        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over individual gradients. The inner list is over the
            gradient calculation for each tower.
    Returns:
        List of pairs of (gradient, variable) where the gradient has been
        averaged across all towers.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        grads = []
        for g, v in grad_and_vars:
            if g is not None:
                # Add 0 dimension to the gradients to represent the tower.
                expanded_g = tf.expand_dims(g, 0)

                # Append on a 'tower' dimension which we will average over below.
                grads.append(expanded_g)
            else:
                print(g, v)

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
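
# Illustrative sketch (the kernel shape is assumed): spectral_normed_weight
# rescales a kernel so that its largest singular value, estimated with
# FLAGS.spec_iter steps of power iteration, is approximately FLAGS.spec_norm_val.
def _spectral_norm_example():
    w = tf.get_variable(
        'example_sn_w', [3, 3, 64, 64],
        initializer=tf.initializers.random_normal())
    return spectral_normed_weight(w, 'example_sn', fc=False)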