mirror of
				https://github.com/autistic-symposium/ml-ai-agents-py.git
				synced 2025-10-25 11:26:15 -04:00 
			
		
		
		
	
		
			
				
	
	
		
			1107 lines
		
	
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1107 lines
		
	
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """ Utility functions. """
 | |
| import numpy as np
 | |
| import os
 | |
| import random
 | |
| import tensorflow as tf
 | |
| import warnings
 | |
| 
 | |
| from tensorflow.contrib.layers.python import layers as tf_layers
 | |
| from tensorflow.python.platform import flags
 | |
| from tensorflow.contrib.framework import sort
 | |
| 
 | |
# Command-line flags read by helpers in this module through FLAGS.
# NOTE(review): FLAGS.spec_norm, FLAGS.norm and FLAGS.input_objects are also
# read below but are defined elsewhere (not in this module chunk).
FLAGS = flags.FLAGS
# Help text: iterations used to normalize a matrix's spectrum; the consumer
# (presumably spectral_normed_weight) is not visible here - confirm.
flags.DEFINE_integer('spec_iter', 1, 'Number of iterations to normalize spectrum of matrix')
# Help text: desired norm of the normalized weight matrices; consumer not visible here.
flags.DEFINE_float('spec_norm_val', 1.0, 'Desired norm of matrices')
# When True, conv_block runs its convolution unstrided and then applies 2x2
# average pooling wherever a strided convolution was requested.
flags.DEFINE_bool('downsample', False, 'Wheter to do average pool downsampling')
# Help text: prevents spectral updates; not read anywhere in this chunk.
flags.DEFINE_bool('spec_eval', False, 'Set to true to prevent spectral updates')
 | |
| 
 | |
| 
 | |
def get_median(v):
    """Return the median of all elements of tensor ``v``.

    ``v`` is flattened to ``n`` elements. The ``(n + 1) // 2`` largest values
    are taken with ``tf.nn.top_k`` (sorted descending), and the last of them
    is returned:

    * odd ``n``: the exact median;
    * even ``n``: the upper of the two middle values.

    ``n == 1`` is handled (``k == 1``).

    Args:
        v: tensor of any shape with at least one element.

    Returns:
        Scalar tensor with the median value.
    """
    flat = tf.reshape(v, [-1])
    n = tf.shape(flat)[0]
    k = (n + 1) // 2
    # top_k returns a (values, indices) namedtuple; the original indexed the
    # tuple itself with a tensor, which cannot select a value.
    return tf.nn.top_k(flat, k).values[k - 1]
 | |
| 
 | |
| 
 | |
def set_seed(seed):
    """Seed every random number generator used by the project.

    In order: PyTorch, NumPy's global RNG, Python's ``random`` module and
    TensorFlow's graph-level seed all receive the same ``seed``.
    """
    import torch as torch_mod
    import numpy as numpy_mod
    import random as py_random

    seeders = (
        torch_mod.manual_seed,
        numpy_mod.random.seed,
        py_random.seed,
        tf.set_random_seed,
    )
    for seeder in seeders:
        seeder(seed)
 | |
| 
 | |
| 
 | |
def swish(inp):
    """Swish activation: ``inp * sigmoid(inp)``."""
    gate = tf.nn.sigmoid(inp)
    return inp * gate
 | |
| 
 | |
| 
 | |
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of samples.

    Once ``size`` items are stored, every new item overwrites the oldest one.
    """

    def __init__(self, size):
        """Create Replay buffer.

        Parameters
        ----------
        size: int
            Max number of items to store in the buffer. When the buffer
            overflows the oldest items are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def add(self, ims):
        """Store every item of ``ims`` (iterated along its first axis).

        Items are written one at a time into the ring, so the buffer never
        holds more than ``size`` items. If ``ims`` is larger than the buffer,
        only its last ``size`` items survive.
        """
        for im in ims:
            if self._next_idx >= len(self._storage):
                # Not yet full: the write position is the end of the list.
                self._storage.append(im)
            else:
                self._storage[self._next_idx] = im
            self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes):
        """Stack the stored items at ``idxes`` into one numpy array."""
        return np.array([self._storage[i] for i in idxes])

    def sample(self, batch_size):
        """Sample stored items uniformly at random, with replacement.

        Parameters
        ----------
        batch_size: int
            How many items to sample.

        Returns
        -------
        np.array
            The sampled items stacked along a new first axis of length
            ``batch_size``.

        Raises
        ------
        ValueError
            If the buffer is empty (``random.randint`` gets an empty range).
        """
        idxes = [random.randint(0, len(self._storage) - 1)
                 for _ in range(batch_size)]
        return self._encode_sample(idxes)
 | |
| 
 | |
| 
 | |
def get_weight(
        name,
        shape,
        gain=np.sqrt(2),
        use_wscale=False,
        fan_in=None,
        spec_norm=False,
        zero=False,
        fc=False):
    """Create (or fetch) the weight variable ``name + 'weight'`` of ``shape``.

    Exactly one scheme is used, checked in this order:

    * ``use_wscale``: N(0, 1) initialization, multiplied at runtime by the
      He constant ``gain / sqrt(fan_in)``. ``fan_in`` defaults to the product
      of all but the last dimension of ``shape``.
    * ``spec_norm``: N(0, 1) initialization (std 1e-10 when ``zero``), passed
      through ``spectral_normed_weight``. ``lower_bound=True`` is added when
      ``zero``, and ``fc`` is forwarded.
    * otherwise: all zeros when ``zero``, else Xavier/Glorot initialization.

    Returns:
        The variable itself, or the tensor derived from it in the wscale and
        spec_norm branches.
    """
    if use_wscale:
        if fan_in is None:
            fan_in = np.prod(shape[:-1])
        std = gain / np.sqrt(fan_in)  # He init
        wscale = tf.constant(np.float32(std), name=name + 'wscale')
        return tf.get_variable(
            name + 'weight',
            shape=shape,
            initializer=tf.initializers.random_normal()) * wscale

    if spec_norm:
        if zero:
            var = tf.get_variable(
                name + 'weight',
                shape=shape,
                initializer=tf.initializers.random_normal(stddev=1e-10))
            return spectral_normed_weight(var, name, lower_bound=True, fc=fc)
        var = tf.get_variable(
            name + 'weight',
            shape=shape,
            initializer=tf.initializers.random_normal())
        return spectral_normed_weight(var, name, fc=fc)

    if zero:
        # Use ``zeros``, the initializer name used everywhere else in this
        # module (the original called ``tf.initializers.zero``).
        initializer = tf.initializers.zeros()
    else:
        initializer = tf.contrib.layers.xavier_initializer(dtype=tf.float32)
    return tf.get_variable(name + 'weight', shape=shape, initializer=initializer)
 | |
| 
 | |
| 
 | |
def pixel_norm(x, epsilon=1e-8):
    """Divide ``x`` by the root-mean-square of its values over axes 1 and 2.

    Computes ``x / sqrt(mean(x**2, axis=[1, 2]) + epsilon)`` inside a
    ``PixelNorm`` variable scope.

    NOTE(review): if inputs are NHWC, axes 1 and 2 are spatial, whereas
    classic pixel normalization averages over channels. Confirm this is
    intended.
    """
    with tf.variable_scope('PixelNorm'):
        mean_square = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
        inv_rms = tf.rsqrt(mean_square + epsilon)
        return x * inv_rms
 | |
| 
 | |
| 
 | |
| # helper
 | |
def get_images(paths, labels, nb_samples=None, shuffle=True):
    """List ``(label, file_path)`` pairs for the files in each directory.

    Args:
        paths: directories to list; paired element-wise with ``labels``.
        labels: label attached to every file of the matching directory.
        nb_samples: if given, draw this many files (without replacement)
            from each directory via ``random.sample``; otherwise keep all.
        shuffle: shuffle the combined list in place before returning it.

    Returns:
        List of ``(label, os.path.join(directory, filename))`` tuples.
    """
    def pick(names):
        return names if nb_samples is None else random.sample(names, nb_samples)

    images = []
    for label, directory in zip(labels, paths):
        for filename in pick(os.listdir(directory)):
            images.append((label, os.path.join(directory, filename)))

    if shuffle:
        random.shuffle(images)
    return images
 | |
| 
 | |
| 
 | |
def optimistic_restore(session, save_file, v_prefix=None):
    """Restore the global variables that exist in ``save_file`` with equal shape.

    Behaviour for variables that are not restored:

    * not in the checkpoint: skipped silently;
    * lookup failure: the exception is printed and the variable is skipped;
    * shape mismatch: the name and both shapes are printed and the variable
      is skipped.

    ``v_prefix`` is accepted for interface compatibility but unused.
    """
    saved_shapes = tf.train.NewCheckpointReader(save_file).get_variable_to_shape_map()

    candidates = []
    for graph_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        base_name = graph_var.name.split(':')[0]
        if base_name in saved_shapes:
            candidates.append((graph_var.name, base_name))
    candidates.sort()

    to_restore = []
    with tf.variable_scope('', reuse=True):
        for full_name, base_name in candidates:
            try:
                found = tf.get_variable(base_name)
            except Exception as err:
                print(err)
                continue
            current_shape = found.get_shape().as_list()
            expected_shape = saved_shapes[base_name]
            if current_shape != expected_shape:
                print(full_name)
                print(current_shape, expected_shape)
                continue
            to_restore.append(found)

    tf.train.Saver(to_restore).restore(session, save_file)
 | |
| 
 | |
| 
 | |
def optimistic_remap_restore(session, save_file, v_prefix):
    """Load scope ``context_<v_prefix>`` variables from ``context_0`` checkpoint entries.

    Each global variable under scope ``context_<v_prefix>`` is renamed by
    replacing that scope with ``context_0``. If the renamed entry exists in
    the checkpoint with the same shape, it is restored into the variable.
    Shape mismatches print the checkpoint name and both shapes and are
    skipped.
    """
    source_scope = 'context_{}'.format(v_prefix)
    saved_shapes = tf.train.NewCheckpointReader(save_file).get_variable_to_shape_map()

    def to_ckpt_name(graph_name):
        return graph_name.replace(source_scope, 'context_0')

    scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=source_scope)
    pairs = sorted(
        (gv.name.split(':')[0], gv) for gv in scoped_vars
        if to_ckpt_name(gv.name.split(':')[0]) in saved_shapes)

    name_map = {}
    with tf.variable_scope('', reuse=True):
        for graph_name, graph_var in pairs:
            ckpt_name = to_ckpt_name(graph_name)
            shape_now = graph_var.get_shape().as_list()
            if shape_now == saved_shapes[ckpt_name]:
                name_map[ckpt_name] = graph_var
            else:
                print(ckpt_name)
                print(shape_now, saved_shapes[ckpt_name])

    tf.train.Saver(name_map).restore(session, save_file)
 | |
| 
 | |
| 
 | |
def remap_restore(session, save_file, i):
    """Restore all global variables whose name and shape match ``save_file``.

    Lookup failures are printed and skipped. Shape mismatches are skipped
    silently. The list of variables to restore is printed before restoring.
    ``i`` is accepted for interface compatibility but unused.
    """
    saved_shapes = tf.train.NewCheckpointReader(save_file).get_variable_to_shape_map()
    names = sorted(
        (gv.name, gv.name.split(':')[0]) for gv in tf.global_variables()
        if gv.name.split(':')[0] in saved_shapes)

    matched = []
    with tf.variable_scope('', reuse=True):
        for _, ckpt_name in names:
            try:
                candidate = tf.get_variable(ckpt_name)
            except Exception as err:
                print(err)
                continue
            if candidate.get_shape().as_list() == saved_shapes[ckpt_name]:
                matched.append(candidate)

    print(matched)
    tf.train.Saver(matched).restore(session, save_file)
 | |
| 
 | |
| 
 | |
| # Network weight initializers
 | |
def init_conv_weight(
        weights,
        scope,
        k,
        c_in,
        c_out,
        spec_norm=True,
        zero=False,
        scale=1.0,
        classes=1):
    """Create the variables of a conv layer and store them in ``weights[scope]``.

    Entries (``P`` is ``[classes]`` when ``classes != 1``, else empty):
      'c'  kernel ``[k, k, c_in, c_out]`` from ``get_weight`` (``zero=True``
           when ``zero``)
      'b'  ``[c_out]``, zeros
      'g'  ``P + [c_out]``, ones
      'gb' ``P + [c_in]``, zeros
      'cb' ``[c_in]``, zeros

    A truthy ``spec_norm`` is replaced by ``FLAGS.spec_norm``, so the flag
    can turn spectral normalization off globally. ``scale`` is unused.
    """
    use_spec = FLAGS.spec_norm if spec_norm else spec_norm
    per_class = [] if classes == 1 else [classes]

    layer = {}
    with tf.variable_scope(scope):
        kernel_kwargs = {'spec_norm': use_spec}
        if zero:
            kernel_kwargs['zero'] = True
        layer['c'] = get_weight('c', [k, k, c_in, c_out], **kernel_kwargs)
        layer['b'] = tf.get_variable(
            shape=[c_out], name='b', initializer=tf.initializers.zeros())
        layer['g'] = tf.get_variable(
            shape=per_class + [c_out], name='g', initializer=tf.initializers.ones())
        layer['gb'] = tf.get_variable(
            shape=per_class + [c_in], name='gb', initializer=tf.initializers.zeros())
        layer['cb'] = tf.get_variable(
            shape=[c_in], name='cb', initializer=tf.initializers.zeros())

    weights[scope] = layer
 | |
| 
 | |
| 
 | |
def init_convt_weight(
        weights,
        scope,
        k,
        c_in,
        c_out,
        spec_norm=True,
        zero=False,
        scale=1.0,
        classes=1):
    """Create the variables of a transposed-conv layer in ``weights[scope]``.

    Entries (``P`` is ``[classes]`` when ``classes != 1``, else empty):
      'c'  kernel ``[k, k, c_in, c_out]`` from ``get_weight`` (``zero=True``
           when ``zero``)
      'b'  ``[c_in]``, zeros
      'g'  ``P + [c_in]``, ones
      'gb' ``P + [c_out]``, zeros
      'cb' ``[c_in]``, zeros

    A truthy ``spec_norm`` is replaced by ``FLAGS.spec_norm``. ``scale`` is
    unused.
    """
    use_spec = FLAGS.spec_norm if spec_norm else spec_norm
    per_class = [classes] if classes != 1 else []
    kernel_shape = [k, k, c_in, c_out]

    layer = {}
    with tf.variable_scope(scope):
        if zero:
            layer['c'] = get_weight('c', kernel_shape, spec_norm=use_spec, zero=True)
        else:
            layer['c'] = get_weight('c', kernel_shape, spec_norm=use_spec)
        layer['b'] = tf.get_variable(
            shape=[c_in], name='b', initializer=tf.initializers.zeros())
        layer['g'] = tf.get_variable(
            shape=per_class + [c_in], name='g', initializer=tf.initializers.ones())
        layer['gb'] = tf.get_variable(
            shape=per_class + [c_out], name='gb', initializer=tf.initializers.zeros())
        layer['cb'] = tf.get_variable(
            shape=[c_in], name='cb', initializer=tf.initializers.zeros())

    weights[scope] = layer
 | |
| 
 | |
| 
 | |
def init_attention_weight(
        weights,
        scope,
        c_in,
        k,
        trainable_gamma=True,
        spec_norm=True):
    """Create query/key/value projections and a gain for an attention layer.

    Stores a dict in ``weights[scope]`` with:
      'q', 'k'         kernels ``[1, 1, c_in, k]`` (variables 'atten_q'/'atten_k')
      'q_b', 'k_b'     biases ``[k]``, zeros ('atten_q_b1'/'atten_k_b1')
      'v'              kernel ``[1, 1, c_in, c_in]`` ('atten_v')
      'v_b'            bias ``[c_in]``, zeros ('atten_v_b1')
      'gamma'          ``[1]``, zeros

    A truthy ``spec_norm`` is replaced by ``FLAGS.spec_norm``.
    ``trainable_gamma`` is unused.
    """
    use_spec = FLAGS.spec_norm if spec_norm else spec_norm
    zeros = tf.initializers.zeros

    params = {}
    with tf.variable_scope(scope):
        for key, out_dim in (('q', k), ('k', k), ('v', c_in)):
            params[key] = get_weight(
                'atten_' + key, [1, 1, c_in, out_dim], spec_norm=use_spec)
            params[key + '_b'] = tf.get_variable(
                shape=[out_dim], name='atten_{}_b1'.format(key), initializer=zeros())
        params['gamma'] = tf.get_variable(
            shape=[1], name='gamma', initializer=zeros())

    weights[scope] = params
 | |
| 
 | |
| 
 | |
def init_fc_weight(weights, scope, c_in, c_out, spec_norm=True):
    """Create a dense layer's weights in ``weights[scope]``.

    Entries:
      'w'  ``[c_in, c_out]`` from ``get_weight(..., fc=True)``
      'b'  ``[c_out]``, zeros

    A truthy ``spec_norm`` is replaced by ``FLAGS.spec_norm``.
    """
    if spec_norm:
        spec_norm = FLAGS.spec_norm

    with tf.variable_scope(scope):
        kernel = get_weight('w', [c_in, c_out], spec_norm=spec_norm, fc=True)
        bias = tf.get_variable(
            shape=[c_out], name='b', initializer=tf.initializers.zeros())

    weights[scope] = {'w': kernel, 'b': bias}
 | |
| 
 | |
| 
 | |
def init_res_weight(
        weights,
        scope,
        k,
        c_in,
        c_out,
        hidden_dim=None,
        spec_norm=True,
        res_scale=1.0,
        classes=1):
    """Create the conv layers of a residual block via ``init_conv_weight``.

    Layers created:
      ``scope + '_res_c1'``        c_in  -> c_out
      ``scope + '_res_c2'``        c_out -> c_out, ``zero=True``
      ``scope + '_res_adaptive'``  c_in  -> c_out, only when ``c_in != c_out``

    A truthy ``spec_norm`` is replaced by ``FLAGS.spec_norm``. ``res_scale``
    is forwarded as ``scale``. ``hidden_dim`` has no effect.
    """
    if spec_norm:
        spec_norm = FLAGS.spec_norm

    shared = dict(spec_norm=spec_norm, scale=res_scale, classes=classes)
    init_conv_weight(weights, scope + '_res_c1', k, c_in, c_out, **shared)
    init_conv_weight(weights, scope + '_res_c2', k, c_out, c_out, zero=True, **shared)
    if c_in != c_out:
        init_conv_weight(weights, scope + '_res_adaptive', k, c_in, c_out, **shared)
 | |
| 
 | |
| # Network forward helpers
 | |
| 
 | |
| 
 | |
def smart_conv_block(inp, weights, reuse, scope, use_stride=True, **kwargs):
    """Call ``conv_block`` with the parameters stored in ``weights[scope]``.

    Mapping: 'c' -> kernel, 'b' -> bias weight, 'g' -> ``scale``,
    'gb' -> ``bias``, 'cb' -> ``class_bias``. Extra keyword arguments are
    forwarded unchanged.
    """
    layer = weights[scope]
    return conv_block(
        inp, layer['c'], layer['b'], reuse, scope,
        scale=layer['g'],
        bias=layer['gb'],
        class_bias=layer['cb'],
        use_stride=use_stride,
        **kwargs)
 | |
| 
 | |
| 
 | |
def smart_convt_block(
        inp,
        weights,
        reuse,
        scope,
        output_dim,
        upsample=True,
        label=None):
    """Transposed conv layer using the parameters stored in ``weights[scope]``.

    With a ``label`` matrix:

    * ``label @ gb`` (reshaped to ``[batch, 1, 1, dim]``) is added to the
      input before the convolution;
    * the output is multiplied by ``label @ g + cb`` (same reshape).

    Without a label neither step is applied. The output has spatial size
    ``output_dim x output_dim`` and ``c.shape[-2]`` channels. Stride is 2
    when ``upsample`` (else 1), padding is 'SAME', and a leaky ReLU is
    applied last. ``reuse`` and the 'b' entry are unused.
    """
    layer = weights[scope]
    kernel, _unused_b, scale, shift, class_bias = (
        layer[key] for key in ('c', 'b', 'g', 'gb', 'cb'))

    step = 2 if upsample else 1
    strides = [1, step, step, 1]

    def per_example(t):
        # [batch, dim] -> [batch, 1, 1, dim] so it broadcasts over H and W.
        dims = tf.shape(t)
        return tf.reshape(t, (dims[0], 1, 1, dims[1]))

    if label is not None:
        inp = inp + per_example(tf.matmul(label, shift))

    out_shape = [tf.shape(inp)[0], output_dim, output_dim,
                 kernel.get_shape().as_list()[-2]]
    out = tf.nn.conv2d_transpose(inp, kernel, out_shape, strides, 'SAME')

    if label is not None:
        out = out * per_example(tf.matmul(label, scale) + class_bias)

    return tf.nn.leaky_relu(out)
 | |
| 
 | |
| 
 | |
def smart_res_block(
        inp,
        weights,
        reuse,
        scope,
        downsample=True,
        adaptive=True,
        stop_batch=False,
        upsample=False,
        label=None,
        act=tf.nn.leaky_relu,
        dropout=False,
        train=False,
        **kwargs):
    """Residual block built from the layers created by ``init_res_weight``.

    Main path:
        conv ``_res_c1`` (extra_bias, label) -> optional dropout (rate 0.5,
        only when ``dropout``) -> ``act`` -> conv ``_res_c2`` (use_scale,
        extra_bias, label).

    Bypass:
        conv ``_res_adaptive`` when ``adaptive``, else the input itself.

    The sum is resized 2x (nearest neighbour) when ``upsample``, otherwise
    2x2 average-pooled when ``downsample``, then passed through ``act``.
    ``stop_batch`` is accepted but not forwarded.
    """
    c1_scope, c2_scope = scope + '_res_c1', scope + '_res_c2'
    # Fail fast with KeyError, before building any ops, if a layer is missing.
    _ = (weights[c1_scope], weights[c2_scope])

    hidden = smart_conv_block(
        inp, weights, reuse, c1_scope,
        use_stride=False, activation=None, extra_bias=True, label=label,
        **kwargs)
    if dropout:
        hidden = tf.layers.dropout(hidden, rate=0.5, training=train)
    hidden = act(hidden)

    branch = smart_conv_block(
        hidden, weights, reuse, c2_scope,
        use_stride=False, activation=None, use_scale=True, extra_bias=True,
        label=label, **kwargs)

    if adaptive:
        shortcut = smart_conv_block(
            inp, weights, reuse, scope + '_res_adaptive',
            use_stride=False, activation=None, **kwargs)
    else:
        shortcut = inp

    out = branch + shortcut

    if upsample:
        static_shape = out.get_shape()
        out = tf.image.resize_nearest_neighbor(
            out, [2 * static_shape[1], 2 * static_shape[2]])
    elif downsample:
        out = tf.nn.avg_pool(out, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

    return act(out)
 | |
| 
 | |
| 
 | |
def smart_res_block_optim(inp, weights, reuse, scope, **kwargs):
    """Downsampling residual block.

    Main path:
        conv ``_res_c1`` -> leaky ReLU -> conv ``_res_c2`` -> 2x2 average pool.

    Bypass:
        2x2 average pool -> conv ``_res_adaptive``.

    Returns the sum of both paths. Extra keyword arguments are forwarded to
    every ``smart_conv_block`` call.
    """
    hidden = smart_conv_block(
        inp, weights, reuse, scope + '_res_c1',
        use_stride=False, activation=None, **kwargs)
    hidden = tf.nn.leaky_relu(hidden)
    main = smart_conv_block(
        hidden, weights, reuse, scope + '_res_c2',
        use_stride=False, activation=None, **kwargs)

    pooled_inp = tf.nn.avg_pool(inp, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
    bypass = smart_conv_block(
        pooled_inp, weights, reuse, scope + '_res_adaptive',
        use_stride=False, activation=None, **kwargs)
    main = tf.nn.avg_pool(main, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')

    # The original built this sum but returned ``main`` alone, silently
    # discarding the bypass branch.
    return main + bypass
 | |
| 
 | |
| 
 | |
def groupsort(k=4):
    """Return a GroupSort activation over groups of ``k`` consecutive values.

    The returned function:

    1. reshapes its input to ``(-1, k)``;
    2. sorts each row with ``sort`` (library-default direction and axis);
    3. restores the original dynamic shape.

    The element count must be divisible by ``k``. The original hard-coded
    4 and ignored ``k``; the default keeps that behavior.
    """
    def sortact(inp):
        old_shape = tf.shape(inp)
        grouped = sort(tf.reshape(inp, (-1, k)))
        return tf.reshape(grouped, old_shape)
    return sortact
 | |
| 
 | |
| 
 | |
def smart_atten_block(inp, weights, reuse, scope, **kwargs):
    """Call ``attention`` with the parameters stored in ``weights[scope]``."""
    params = weights[scope]
    q, q_b, k, k_b, v, v_b, gamma = (
        params[key] for key in ('q', 'q_b', 'k', 'k_b', 'v', 'v_b', 'gamma'))
    return attention(inp, q, q_b, k, k_b, v, v_b, gamma, reuse, scope, **kwargs)
 | |
| 
 | |
| 
 | |
def smart_fc_block(inp, weights, reuse, scope, use_bias=True):
    """Dense layer ``inp @ w``, plus ``b`` when ``use_bias``, from ``weights[scope]``.

    ``reuse`` is unused.
    """
    layer = weights[scope]
    projected = tf.matmul(inp, layer['w'])
    return projected + layer['b'] if use_bias else projected
 | |
| 
 | |
| 
 | |
| # Network helpers
 | |
def conv_block(
        inp,
        cweight,
        bweight,
        reuse,
        scope,
        use_stride=True,
        activation=tf.nn.leaky_relu,
        pn=False,
        bn=False,
        gn=False,
        ln=False,
        scale=None,
        bias=None,
        class_bias=None,
        use_bias=False,
        downsample=False,
        stop_batch=False,
        use_scale=False,
        extra_bias=False,
        average=False,
        label=None):
    """Conv2d layer with optional conditioning, bias, activation and norms.

    Steps, in order:
      1. If ``extra_bias``: add ``bias`` to the *input*. With a ``label``,
         the bias is first made per-example as ``label @ bias`` reshaped to
         ``[batch, 1, 1, dim]``; a rank-1 ``bias`` is treated as one row.
      2. ``tf.nn.conv2d`` with 'SAME' padding. Stride is 2 in axes 1 and 2
         when ``use_stride`` and not ``FLAGS.downsample``, otherwise 1.
      3. If ``use_scale``: multiply by ``scale``. With a ``label``, the scale
         is ``label @ scale + class_bias`` reshaped to ``[batch, 1, 1, dim]``.
      4. If ``use_bias``: add ``bweight``; otherwise ``bweight`` is ignored.
      5. Apply ``activation`` unless it is None.
      6. Optional normalizations, in the order bn, pn, gn, ln.
      7. If ``FLAGS.downsample`` and ``use_stride``: 2x2 average pooling
         with stride 2, replacing the stride skipped in step 2.

    ``reuse``, ``scope``, ``downsample`` and ``average`` are not used here.
    No max pooling is performed.

    NOTE(review): the norms in step 6 receive ``scale``/``bias`` as they are
    at that point. They may be the per-example tensors built in steps 1 and
    3 rather than the original arguments. Confirm this is intended.
    """
    stride, no_stride = [1, 2, 2, 1], [1, 1, 1, 1]
    # h and w are unused; the unpack only requires a rank-4 static shape.
    _, h, w, _ = inp.get_shape()

    # With FLAGS.downsample the conv is unstrided; pooling at the end downsamples.
    if FLAGS.downsample:
        stride = no_stride

    if not use_bias:
        bweight = 0

    if extra_bias:
        if label is not None:
            if len(bias.get_shape()) == 1:
                bias = tf.reshape(bias, (1, -1))
            # Per-example bias: [batch, dim] -> [batch, 1, 1, dim] to broadcast spatially.
            bias_batch = tf.matmul(label, bias)
            batch = tf.shape(bias_batch)[0]
            dim = tf.shape(bias_batch)[1]
            bias = tf.reshape(bias_batch, (batch, 1, 1, dim))

        inp = inp + bias

    if not use_stride:
        conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME')
    else:
        conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME')

    if use_scale:
        if label is not None:
            if len(scale.get_shape()) == 1:
                scale = tf.reshape(scale, (1, -1))
            # Per-example gain: label @ scale + class_bias, broadcast spatially.
            scale_batch = tf.matmul(label, scale) + class_bias
            batch = tf.shape(scale_batch)[0]
            dim = tf.shape(scale_batch)[1]
            scale = tf.reshape(scale_batch, (batch, 1, 1, dim))

        conv_output = conv_output * scale

    if use_bias:
        conv_output = conv_output + bweight

    if activation is not None:
        conv_output = activation(conv_output)

    if bn:
        conv_output = batch_norm(conv_output, scale, bias)
    if pn:
        conv_output = pixel_norm(conv_output)
    if gn:
        conv_output = group_norm(
            conv_output, scale, bias, stop_batch=stop_batch)
    if ln:
        conv_output = layer_norm(conv_output, scale, bias)

    if FLAGS.downsample and use_stride:
        conv_output = tf.layers.average_pooling2d(conv_output, (2, 2), 2)

    return conv_output
 | |
| 
 | |
| 
 | |
def conv_block_1d(
        inp,
        cweight,
        bweight,
        reuse,
        scope,
        activation=tf.nn.leaky_relu):
    """1-D convolution (stride 1, 'SAME' padding) plus ``bweight``, then ``activation``.

    Pass ``activation=None`` to skip the nonlinearity. ``reuse`` and
    ``scope`` are unused.
    """
    out = tf.nn.conv1d(inp, cweight, 1, 'SAME') + bweight
    return out if activation is None else activation(out)
 | |
| 
 | |
| 
 | |
def conv_block_3d(
        inp,
        cweight,
        bweight,
        reuse,
        scope,
        use_stride=True,
        activation=tf.nn.leaky_relu,
        pn=False,
        bn=False,
        gn=False,
        ln=False,
        scale=None,
        bias=None,
        use_bias=False):
    """3-D conv, optional bias, activation, optional norms, optional pooling.

    * ``tf.nn.conv3d`` with 'SAME' padding. Stride is 2 in axes 2 and 3
      (axis 1 is never strided) when ``use_stride`` and not
      ``FLAGS.downsample``, otherwise 1.
    * ``bweight`` is added only when ``use_bias``.
    * ``activation`` is called with ``alpha=0.1``, so it must accept that
      keyword (e.g. ``tf.nn.leaky_relu``). Pass None to skip it.
    * Norms are applied in the order bn, pn, gn, ln.
    * With ``FLAGS.downsample`` and ``use_stride``, axes 2 and 3 are
      average-pooled by 2 instead of striding the convolution. This mirrors
      ``conv_block``.

    ``reuse`` and ``scope`` are unused.
    """
    stride, no_stride = [1, 1, 2, 2, 1], [1, 1, 1, 1, 1]

    # Mirror conv_block: under FLAGS.downsample the conv is unstrided and the
    # pooling below does the downsampling. The original kept the stride and
    # then called average_pooling2d, which rejects rank-5 tensors.
    if FLAGS.downsample:
        stride = no_stride

    if not use_bias:
        bweight = 0

    conv_stride = stride if use_stride else no_stride
    conv_output = tf.nn.conv3d(inp, cweight, conv_stride, 'SAME') + bweight

    if activation is not None:
        conv_output = activation(conv_output, alpha=0.1)

    if bn:
        conv_output = batch_norm(conv_output, scale, bias)
    if pn:
        conv_output = pixel_norm(conv_output)
    if gn:
        conv_output = group_norm(conv_output, scale, bias)
    if ln:
        conv_output = layer_norm(conv_output, scale, bias)

    if FLAGS.downsample and use_stride:
        # Pool only the axes the conv stride would have halved (not axis 1).
        conv_output = tf.layers.average_pooling3d(conv_output, (1, 2, 2), (1, 2, 2))

    return conv_output
 | |
| 
 | |
| 
 | |
def group_norm(inp, scale, bias, g=32, eps=1e-6, stop_batch=False):
    """Applies group normalization assuming nhwc format.

    Channels are split into ``c // g`` groups of ``g`` consecutive channels,
    so ``g`` is the group *size*, not the number of groups. Each group is
    normalized to zero mean and unit variance over H, W and its channels.
    The result is then multiplied by ``scale`` and shifted by ``bias`` when
    those are not None. ``stop_batch`` is unused.
    """
    _, height, width, channels = inp.shape
    grouped = tf.reshape(inp, (tf.shape(inp)[0], height, width, channels // g, g))

    mean, var = tf.nn.moments(grouped, [1, 2, 4], keep_dims=True)
    inv_std = tf.rsqrt(var + eps)
    normed = inv_std * (grouped - mean)
    output = tf.reshape(normed, (tf.shape(grouped)[0], height, width, channels))

    if scale is not None:
        output = output * scale
    if bias is not None:
        output = output + bias
    return output
 | |
| 
 | |
| 
 | |
def layer_norm(inp, scale, bias, eps=1e-6):
    """Applies layer normalization assuming nhwc format.

    Each example is normalized to zero mean and unit variance over all of its
    axes 1-3 (H, W, C). The result is then multiplied by ``scale`` and
    shifted by ``bias`` when those are not None.

    Args:
        inp: rank-4 tensor; the static-shape unpack below rejects other ranks.
        scale: optional multiplicative gain, or None.
        bias: optional additive offset, or None.
        eps: added to the variance before the reciprocal square root.

    Returns:
        Tensor with the same shape as ``inp``.
    """
    # Unused values; the unpack only enforces a rank-4 static shape.
    n, h, w, c = inp.shape

    mean, var = tf.nn.moments(inp, [1, 2, 3], keep_dims=True)
    gain = tf.rsqrt(var + eps)
    output = gain * (inp - mean)

    if scale is not None:
        output = output * scale

    if bias is not None:
        output = output + bias

    return output
 | |
| 
 | |
| 
 | |
def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis.

    ``y`` is broadcast to ``x``'s sizes on axes 0-2 (keeping its own
    axis-3 size), divided by 10, and appended to ``x`` along axis 3.
    """
    xs = tf.shape(x)
    ys = tf.shape(y)
    tiled = y * tf.ones([xs[0], xs[1], xs[2], ys[3]])
    return tf.concat([x, tiled / 10.], axis=3)
 | |
| 
 | |
| 
 | |
def attention(
        inp,
        q,
        q_b,
        k,
        k_b,
        v,
        v_b,
        gamma,
        reuse,
        scope,
        stop_at_grad=False,
        seperate=False,
        scale=False,
        train=False,
        dropout=0.0):
    """Self-attention over flattened spatial positions, with a gated residual.

    Projections are made with unstrided ``conv_block`` calls:

    * query/key: no activation, with bias;
    * value: ``conv_block`` defaults, i.e. leaky ReLU and no bias, so
      ``v_b`` has no effect.

    Logits are ``Q K^T`` over positions flattened by ``hw_flatten``. They are
    divided by ``sqrt(channels of Q)`` when ``scale``, and passed through
    ``tf.nn.dropout(logits, 0.9)`` when ``train``. The ``dropout`` argument
    itself is unused.

    Returns ``inp + gamma * O``, or only ``gamma * O`` when ``seperate``.
    ``stop_at_grad`` is unused.
    """
    common = dict(reuse=reuse, scope=scope, use_stride=False,
                  pn=False, bn=False, gn=False)
    conv_q = conv_block(inp, q, q_b, activation=None, use_bias=True, **common)
    conv_k = conv_block(inp, k, k_b, activation=None, use_bias=True, **common)
    conv_v = conv_block(inp, v, v_b, **common)

    q_channels = float(conv_q.get_shape().as_list()[-1])
    logits = tf.matmul(hw_flatten(conv_q), hw_flatten(conv_k), transpose_b=True)
    if scale:
        logits = logits / q_channels ** 0.5
    if train:
        logits = tf.nn.dropout(logits, 0.9)

    beta = tf.nn.softmax(logits, axis=-1)
    attended = tf.reshape(tf.matmul(beta, hw_flatten(conv_v)), shape=tf.shape(inp))
    gated = gamma * attended
    return gated if seperate else inp + gated
 | |
| 
 | |
| 
 | |
def attention_2d(
        inp,
        q,
        q_b,
        k,
        k_b,
        v,
        v_b,
        reuse,
        scope,
        stop_at_grad=False,
        seperate=False,
        scale=False):
    """Attention among the entries of axis 2 of a rank-4 ``inp``, plus residual.

    Computation:

    * Rows of ``inp`` are projected with dense matmuls:
      query ``x @ q + q_b``, key ``x @ k + k_b``, value
      ``leaky_relu(x @ v + v_b)``.
    * They are reshaped back to ``[d0, d1, d2, channels]``.
    * Logits are ``K Q^T`` (batched over ``d0, d1``), optionally divided by
      ``sqrt(32)`` when ``scale``, and softmax-normalized.
    * Returns ``inp + softmax(logits) @ V`` reshaped to ``inp``'s shape.

    NOTE(review): the row count uses ``d0 * FLAGS.input_objects * d1``; this
    presumably requires ``FLAGS.input_objects == d2``. The value projection
    must also keep ``inp``'s channel count. Confirm both.
    ``reuse``, ``scope``, ``stop_at_grad`` and ``seperate`` are unused.
    """
    dims = tf.shape(inp)
    rows = dims[0] * FLAGS.input_objects * dims[1]
    flat = tf.reshape(inp, (rows, inp.shape[3]))

    def to_4d(t, channels):
        return tf.reshape(t, (dims[0], dims[1], dims[2], channels))

    f_q = tf.matmul(flat, q) + q_b
    f_k = tf.matmul(flat, k) + k_b
    f_v = tf.nn.leaky_relu(tf.matmul(flat, v) + v_b)

    f_q = to_4d(f_q, tf.shape(f_q)[-1])
    f_k = to_4d(f_k, tf.shape(f_k)[-1])
    f_v = to_4d(f_v, dims[3])

    logits = tf.matmul(f_k, f_q, transpose_b=True)
    if scale:
        logits = logits / (32 ** 0.5)

    attn = tf.nn.softmax(logits, axis=-1)
    return tf.reshape(tf.matmul(attn, f_v), dims) + inp
 | |
| 
 | |
| 
 | |
def hw_flatten(x):
    """Merge all axes between the first and the last into a single axis.

    Returns ``x`` reshaped to ``[dims[0], -1, dims[-1]]``, where ``dims`` is
    the run-time shape from ``tf.shape(x)``.
    """
    dims = tf.shape(x)
    leading, trailing = dims[0], dims[-1]
    return tf.reshape(x, [leading, -1, trailing])
 | |
| 
 | |
| 
 | |
def batch_norm(inp, scale, bias, eps=0.01):
    """Normalise ``inp`` using mean/variance taken over axis 0 only.

    ``scale`` is the multiplier and ``bias`` the offset handed to
    ``tf.nn.batch_normalization``; ``eps`` is its variance epsilon.
    """
    batch_mean, batch_var = tf.nn.moments(inp, [0])
    return tf.nn.batch_normalization(
        inp,
        batch_mean,
        batch_var,
        offset=bias,
        scale=scale,
        variance_epsilon=eps)
 | |
| 
 | |
| 
 | |
def normalize(inp, activation, reuse, scope):
    """Apply the normalisation chosen by ``FLAGS.norm``, then ``activation``.

    * ``'batch_norm'`` / ``'layer_norm'``: delegate to the contrib layer of
      the same name, passing ``activation`` as ``activation_fn`` along with
      ``reuse`` and ``scope``.
    * ``'None'`` (the string): skip normalisation. Apply ``activation`` if it
      is not None, otherwise return ``inp`` unchanged.

    Raises:
        ValueError: for any other ``FLAGS.norm`` value, so a mistyped flag
            fails here instead of returning None to the caller.
    """
    norm = FLAGS.norm
    if norm == 'batch_norm':
        return tf_layers.batch_norm(
            inp, activation_fn=activation, reuse=reuse, scope=scope)
    if norm == 'layer_norm':
        return tf_layers.layer_norm(
            inp, activation_fn=activation, reuse=reuse, scope=scope)
    if norm == 'None':
        return activation(inp) if activation is not None else inp
    raise ValueError('Unknown normalization type: %r' % (norm,))
 | |
| 
 | |
| # Loss functions
 | |
| 
 | |
| 
 | |
def mse(pred, label):
    """Mean squared error between ``pred`` and ``label``.

    Both tensors are flattened to 1-D before the element-wise difference is
    taken; the mean runs over all resulting elements.
    """
    residual = tf.reshape(pred, [-1]) - tf.reshape(label, [-1])
    return tf.reduce_mean(tf.square(residual))
 | |
| 
 | |
| 
 | |
# Module-level string sentinel.
# NOTE(review): no use is visible near this definition; check whether
# anything elsewhere still references it.
NO_OPS = "NO_OPS"
 | |
| 
 | |
| 
 | |
def _l2normalize(v, eps=1e-12):
    """Divide ``v`` by its global L2 norm plus ``eps``.

    The norm is taken over every element of ``v`` (``tf.reduce_sum`` with no
    axis). ``eps`` is added to the norm itself, keeping the division finite
    for an all-zero ``v``.
    """
    squared_sum = tf.reduce_sum(v ** 2)
    return v / (squared_sum ** 0.5 + eps)
 | |
| 
 | |
| 
 | |
def spectral_normed_weight(w, name, lower_bound=False, iteration=1, fc=False):
    """Rescale ``w`` so its estimated spectral norm becomes ``FLAGS.spec_norm_val``.

    ``w`` is viewed as a matrix of shape ``(-1, w_shape[-1])``. Its largest
    singular value ``sigma`` is estimated with ``FLAGS.spec_iter`` steps of
    power iteration, starting from the non-trainable variable ``<name>_u``.
    Unless ``FLAGS.spec_eval`` is set, that variable is assigned the new
    estimate, and the normalised weight is built under a control dependency
    on the assignment.

    Args:
        w: weight tensor whose static shape is fully known
            (``w.shape.as_list()`` is reused as a reshape target).
        name: prefix of the ``<name>_u`` variable fetched via
            ``tf.get_variable``.
        lower_bound: if True, ``sigma`` is first increased by ``1e-6``.
            The result is then ``w * sigma_new`` when ``sigma < 1`` and
            ``w / sigma * sigma_new`` otherwise.
        iteration, fc: ignored. The iteration count always comes from
            ``FLAGS.spec_iter``; both are kept for call-site compatibility.

    Returns:
        Tensor with the same shape as ``w``.

    Raises:
        ValueError: if ``FLAGS.spec_iter`` is below 1. In that case no power
            iteration would run and there would be no ``v`` estimate.
    """
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    # The ``iteration``/``fc`` arguments are always overridden by this flag;
    # the former ``if fc: iteration = 2`` was a dead store.
    iteration = FLAGS.spec_iter
    if iteration < 1:
        raise ValueError(
            'FLAGS.spec_iter must be >= 1 for spectral normalization, got %r'
            % (iteration,))
    sigma_new = FLAGS.spec_norm_val

    u = tf.get_variable(name + "_u",
                        [1,
                         w_shape[-1]],
                        initializer=tf.random_normal_initializer(),
                        trainable=False)

    # Power iteration. Usually one step is enough: ``u`` is written back
    # below, so each call continues from the previous estimate.
    u_hat = u
    v_hat = None
    for _ in range(iteration):
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    # (1, 1) estimate of the largest singular value: v W u^T.
    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    if FLAGS.spec_eval:
        # Leave the persistent ``u`` estimate untouched.
        dep = []
    else:
        dep = [u.assign(u_hat)]

    # Ops created inside this context run only after ``u`` is updated.
    with tf.control_dependencies(dep):
        if lower_bound:
            sigma = sigma + 1e-6
            w_norm = w / sigma * tf.minimum(sigma, 1) * sigma_new
        else:
            w_norm = w / sigma * sigma_new

        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm
 | |
| 
 | |
| 
 | |
def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
      tower_grads: List of lists of (gradient, variable) tuples. The outer list
        is over towers. Each inner list holds one (gradient, variable) pair per
        variable, in the same variable order for every tower.
    Returns:
       List of pairs of (gradient, variable) where the gradient has been averaged
       across all towers. A ``None`` gradient from a tower is printed and left
       out of that variable's average. If every tower reports ``None`` for a
       variable, the pair is returned as ``(None, variable)`` rather than
       failing inside ``tf.concat``.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        grads = []
        for g, v in grad_and_vars:
            if g is not None:
                # Add a leading 'tower' dimension which we average over below.
                grads.append(tf.expand_dims(g, 0))
            else:
                print(g, v)

        # The Variables are shared across towers, so the first tower's
        # handle stands for all of them.
        v = grad_and_vars[0][1]

        if not grads:
            # No tower produced a gradient for this variable, and tf.concat
            # cannot join an empty list. Emit (None, var), the same shape
            # compute_gradients uses for variables without a gradient.
            average_grads.append((None, v))
            continue

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)
        average_grads.append((grad, v))
    return average_grads
 | 
