# ml-ai-agents-py/EBMs/models.py
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from utils import (
conv_cond_concat,
groupsort,
init_attention_weight,
init_conv_weight,
init_fc_weight,
init_res_weight,
smart_atten_block,
smart_conv_block,
smart_fc_block,
smart_res_block,
swish,
)
flags.DEFINE_bool("swish_act", False, "use the swish activation for dsprites")
FLAGS = flags.FLAGS
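

# The `stop_grad` branch below was duplicated verbatim in every forward()
# method; it is factored out here so all the models share one copy.
def stop_gradient_weights(weights):
    """Return a copy of `weights` with tf.stop_gradient applied to every
    tensor, descending one level into nested weight dicts."""
    weights = weights.copy()
    for k, v in weights.items():
        if isinstance(v, dict):
            v = v.copy()
            weights[k] = v
            for k_sub, v_sub in v.items():
                v[k_sub] = tf.stop_gradient(v_sub)
        else:
            weights[k] = tf.stop_gradient(v)
    return weights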
class MnistNet(object):
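    """Energy network for 28x28 MNIST images.

    Three strided conv blocks followed by two fully connected layers;
    forward() returns a scalar energy per image, optionally conditioned
    on a one-hot class label when FLAGS.cclass is set.
    """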
def __init__(self, num_channels=1, num_filters=64):
self.channels = num_channels
self.dim_hidden = num_filters
self.datasource = FLAGS.datasource
if FLAGS.cclass:
self.label_size = 10
else:
self.label_size = 0
def construct_weights(self, scope=""):
weights = {}
classes = 1
with tf.variable_scope(scope):
init_conv_weight(weights, "c1_pre", 3, 1, 64)
init_conv_weight(weights, "c1", 4, 64, self.dim_hidden, classes=classes)
init_conv_weight(
weights, "c2", 4, self.dim_hidden, 2 * self.dim_hidden, classes=classes
)
init_conv_weight(
weights,
"c3",
4,
2 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_fc_weight(
weights,
"fc_dense",
4 * 4 * 4 * self.dim_hidden,
2 * self.dim_hidden,
spec_norm=True,
)
init_fc_weight(weights, "fc5", 2 * self.dim_hidden, 1, spec_norm=False)
if FLAGS.cclass:
self.label_size = 10
else:
self.label_size = 0
return weights
def forward(
self, inp, weights, reuse=False, scope="", stop_grad=False, label=None, **kwargs
):
weights = weights.copy()
inp = tf.reshape(inp, (tf.shape(inp)[0], 28, 28, 1))
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
        if stop_grad:
            weights = stop_gradient_weights(weights)
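        # Conditioning: broadcast the one-hot label over the spatial grid and
        # concatenate it to the input as extra channels.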
if FLAGS.cclass:
label_d = tf.reshape(
label, shape=(tf.shape(label)[0], 1, 1, self.label_size)
)
inp = conv_cond_concat(inp, label_d)
h1 = smart_conv_block(
inp, weights, reuse, "c1_pre", use_stride=False, activation=act
)
h2 = smart_conv_block(
h1,
weights,
reuse,
"c1",
use_stride=True,
downsample=True,
label=label,
extra_bias=False,
activation=act,
)
h3 = smart_conv_block(
h2,
weights,
reuse,
"c2",
use_stride=True,
downsample=True,
label=label,
extra_bias=False,
activation=act,
)
h4 = smart_conv_block(
h3,
weights,
reuse,
"c3",
use_stride=True,
downsample=True,
label=label,
use_scale=False,
extra_bias=False,
activation=act,
)
h5 = tf.reshape(h4, [-1, np.prod([int(dim) for dim in h4.get_shape()[1:]])])
h6 = act(smart_fc_block(h5, weights, reuse, "fc_dense"))
hidden6 = smart_fc_block(h6, weights, reuse, "fc5")
return hidden6
class DspritesNet(object):
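    """Energy network for 64x64 dSprites images.

    The label size depends on which latent factor the model is conditioned
    on (shape, position, size, or rotation), selected either through flags
    or the cond_* constructor arguments.
    """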
def __init__(
self,
num_channels=1,
num_filters=64,
cond_size=False,
cond_shape=False,
cond_pos=False,
cond_rot=False,
label_size=1,
):
self.channels = num_channels
self.dim_hidden = num_filters
self.img_size = 64
self.label_size = label_size
if FLAGS.cclass:
self.label_size = 3
try:
if FLAGS.dshape_only:
self.label_size = 3
if FLAGS.dpos_only:
self.label_size = 2
if FLAGS.dsize_only:
self.label_size = 1
if FLAGS.drot_only:
self.label_size = 2
        except AttributeError:
            # These flags are only defined by some entry points; fall back to
            # the constructor arguments when they are absent.
            pass
if cond_size:
self.label_size = 1
if cond_shape:
self.label_size = 3
if cond_pos:
self.label_size = 2
if cond_rot:
self.label_size = 2
        self.cond_size = cond_size
        self.cond_shape = cond_shape
        self.cond_pos = cond_pos
        self.cond_rot = cond_rot
def construct_weights(self, scope=""):
weights = {}
classes = self.label_size
with tf.variable_scope(scope):
init_conv_weight(weights, "c1_pre", 3, 1, 32)
init_conv_weight(weights, "c1", 4, 32, self.dim_hidden, classes=classes)
init_conv_weight(
weights, "c2", 4, self.dim_hidden, 2 * self.dim_hidden, classes=classes
)
init_conv_weight(
weights,
"c3",
4,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_conv_weight(
weights,
"c4",
4,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_fc_weight(
weights,
"fc_dense",
2 * 4 * 4 * self.dim_hidden,
2 * self.dim_hidden,
spec_norm=True,
)
init_fc_weight(weights, "fc5", 2 * self.dim_hidden, 1, spec_norm=False)
return weights
def forward(
self,
inp,
weights,
reuse=False,
scope="",
stop_grad=False,
label=None,
stop_at_grad=False,
stop_batch=False,
return_logit=False,
):
batch_size = tf.shape(inp)[0]
inp = tf.reshape(inp, (batch_size, 64, 64, 1))
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
if not FLAGS.cclass:
label = None
weights = weights.copy()
        if stop_grad:
            weights = stop_gradient_weights(weights)
h1 = smart_conv_block(
inp, weights, reuse, "c1_pre", use_stride=False, activation=act
)
h2 = smart_conv_block(
h1,
weights,
reuse,
"c1",
use_stride=True,
downsample=True,
label=label,
extra_bias=True,
activation=act,
)
h3 = smart_conv_block(
h2,
weights,
reuse,
"c2",
use_stride=True,
downsample=True,
label=label,
extra_bias=True,
activation=act,
)
h4 = smart_conv_block(
h3,
weights,
reuse,
"c3",
use_stride=True,
downsample=True,
label=label,
use_scale=True,
extra_bias=True,
activation=act,
)
h5 = smart_conv_block(
h4,
weights,
reuse,
"c4",
use_stride=True,
downsample=True,
label=label,
extra_bias=True,
activation=act,
)
hidden6 = tf.reshape(h5, (tf.shape(h5)[0], -1))
hidden7 = act(smart_fc_block(hidden6, weights, reuse, "fc_dense"))
energy = smart_fc_block(hidden7, weights, reuse, "fc5")
if return_logit:
return hidden7
else:
return energy
class ResNet32(object):
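    """Residual energy network for 32x32 images (e.g. CIFAR-10)."""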
def __init__(self, num_channels=3, num_filters=128):
self.channels = num_channels
self.dim_hidden = num_filters
self.groupsort = groupsort()
def construct_weights(self, scope=""):
weights = {}
if FLAGS.cclass:
classes = 10
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, "c1_pre", 3, self.channels, self.dim_hidden)
init_res_weight(
weights,
"res_optim",
3,
self.dim_hidden,
self.dim_hidden,
classes=classes,
)
init_res_weight(
weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights,
"res_2",
3,
self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_3",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_4",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_5",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_fc_weight(
weights, "fc_dense", 4 * 4 * 2 * self.dim_hidden, 4 * self.dim_hidden
)
init_fc_weight(weights, "fc5", 2 * self.dim_hidden, 1, spec_norm=False)
init_attention_weight(
weights,
"atten",
2 * self.dim_hidden,
                self.dim_hidden // 2,
trainable_gamma=True,
)
return weights
def forward(
self,
inp,
weights,
reuse=False,
scope="",
stop_grad=False,
label=None,
stop_at_grad=False,
stop_batch=False,
):
weights = weights.copy()
act = tf.nn.leaky_relu
if not FLAGS.cclass:
label = None
        if stop_grad:
            weights = stop_gradient_weights(weights)
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, "c1_pre", use_stride=False)
hidden1 = smart_res_block(
inp, weights, reuse, "res_optim", adaptive=False, label=label, act=act
)
hidden2 = smart_res_block(
hidden1,
weights,
reuse,
"res_1",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
act=act,
)
hidden3 = smart_res_block(
hidden2,
weights,
reuse,
"res_2",
stop_batch=stop_batch,
label=label,
act=act,
)
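        # Optionally swap the next residual block for a self-attention block.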
if FLAGS.use_attention:
hidden4 = smart_atten_block(
hidden3, weights, reuse, "atten", stop_at_grad=stop_at_grad, label=label
)
else:
hidden4 = smart_res_block(
hidden3,
weights,
reuse,
"res_3",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
act=act,
)
hidden5 = smart_res_block(
hidden4,
weights,
reuse,
"res_4",
stop_batch=stop_batch,
adaptive=False,
label=label,
act=act,
)
        hidden6 = smart_res_block(
hidden5,
weights,
reuse,
"res_5",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
        out = tf.nn.relu(hidden6)
        pooled = tf.reduce_sum(out, [1, 2])
        energy = smart_fc_block(pooled, weights, reuse, "fc5")
        return energy
class ResNet32Large(object):
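    """Deeper 32x32 variant with nine residual blocks and optional dropout
    during training."""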
def __init__(self, num_channels=3, num_filters=128, train=False):
self.channels = num_channels
self.dim_hidden = num_filters
self.dropout = train
self.train = train
def construct_weights(self, scope=""):
weights = {}
if FLAGS.cclass:
classes = 10
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, "c1_pre", 3, self.channels, self.dim_hidden)
init_res_weight(
weights,
"res_optim",
3,
self.dim_hidden,
self.dim_hidden,
classes=classes,
)
init_res_weight(
weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights, "res_2", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights,
"res_3",
3,
self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_4",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_5",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_6",
3,
2 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_7",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_8",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_fc_weight(weights, "fc5", 4 * self.dim_hidden, 1, spec_norm=False)
init_attention_weight(
weights,
"atten",
2 * self.dim_hidden,
self.dim_hidden,
trainable_gamma=True,
)
return weights
def forward(
self,
inp,
weights,
reuse=False,
scope="",
stop_grad=False,
label=None,
stop_at_grad=False,
stop_batch=False,
):
weights = weights.copy()
if not FLAGS.cclass:
label = None
        if stop_grad:
            weights = stop_gradient_weights(weights)
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, "c1_pre", use_stride=False)
dropout = self.dropout
train = self.train
hidden1 = smart_res_block(
inp,
weights,
reuse,
"res_optim",
adaptive=False,
label=label,
dropout=dropout,
train=train,
)
hidden2 = smart_res_block(
hidden1,
weights,
reuse,
"res_1",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
dropout=dropout,
train=train,
)
hidden3 = smart_res_block(
hidden2,
weights,
reuse,
"res_2",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
dropout=dropout,
train=train,
)
hidden4 = smart_res_block(
hidden3,
weights,
reuse,
"res_3",
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
)
if FLAGS.use_attention:
hidden5 = smart_atten_block(
hidden4, weights, reuse, "atten", stop_at_grad=stop_at_grad
)
else:
hidden5 = smart_res_block(
hidden4,
weights,
reuse,
"res_4",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
)
hidden6 = smart_res_block(
hidden5,
weights,
reuse,
"res_5",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
)
hidden7 = smart_res_block(
hidden6,
weights,
reuse,
"res_6",
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
)
hidden8 = smart_res_block(
hidden7,
weights,
reuse,
"res_7",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
)
        hidden9 = smart_res_block(
hidden8,
weights,
reuse,
"res_8",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
)
        if FLAGS.cclass:
            out = tf.nn.leaky_relu(hidden9)
        else:
            out = tf.nn.relu(hidden9)
        pooled = tf.reduce_sum(out, [1, 2])
        energy = smart_fc_block(pooled, weights, reuse, "fc5")
        return energy
class ResNet32Wider(object):
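    """Wider 32x32 variant; when FLAGS.cclass is set it conditions on 10
    classes for CIFAR-10 or 1000 classes for ImageNet."""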
def __init__(self, num_channels=3, num_filters=128, train=False):
self.channels = num_channels
self.dim_hidden = num_filters
self.dropout = train
self.train = train
def construct_weights(self, scope=""):
weights = {}
if FLAGS.cclass and FLAGS.dataset == "cifar10":
classes = 10
elif FLAGS.cclass and FLAGS.dataset == "imagenet":
classes = 1000
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, "c1_pre", 3, self.channels, 128)
init_res_weight(
weights, "res_optim", 3, 128, self.dim_hidden, classes=classes
)
init_res_weight(
weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights, "res_2", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights,
"res_3",
3,
self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_4",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_5",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_6",
3,
2 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_7",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_8",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_fc_weight(weights, "fc5", 4 * self.dim_hidden, 1, spec_norm=False)
init_attention_weight(
weights,
"atten",
self.dim_hidden,
                self.dim_hidden // 2,
trainable_gamma=True,
)
return weights
def forward(
self,
inp,
weights,
reuse=False,
scope="",
stop_grad=False,
label=None,
stop_at_grad=False,
stop_batch=False,
):
weights = weights.copy()
if not FLAGS.cclass:
label = None
        if stop_grad:
            weights = stop_gradient_weights(weights)
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
# Make sure gradients are modified a bit
inp = smart_conv_block(
inp, weights, reuse, "c1_pre", use_stride=False, activation=act
)
dropout = self.dropout
train = self.train
hidden1 = smart_res_block(
inp,
weights,
reuse,
"res_optim",
adaptive=True,
label=label,
dropout=dropout,
train=train,
)
if FLAGS.use_attention:
hidden2 = smart_atten_block(
hidden1,
weights,
reuse,
"atten",
train=train,
dropout=dropout,
stop_at_grad=stop_at_grad,
)
else:
hidden2 = smart_res_block(
hidden1,
weights,
reuse,
"res_1",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden3 = smart_res_block(
hidden2,
weights,
reuse,
"res_2",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden4 = smart_res_block(
hidden3,
weights,
reuse,
"res_3",
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden5 = smart_res_block(
hidden4,
weights,
reuse,
"res_4",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden6 = smart_res_block(
hidden5,
weights,
reuse,
"res_5",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden7 = smart_res_block(
hidden6,
weights,
reuse,
"res_6",
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden8 = smart_res_block(
hidden7,
weights,
reuse,
"res_7",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden9 = smart_res_block(
hidden8,
weights,
reuse,
"res_8",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
)
        if FLAGS.swish_act:
            out = act(hidden9)
        else:
            out = tf.nn.relu(hidden9)
        pooled = tf.reduce_sum(out, [1, 2])
        energy = smart_fc_block(pooled, weights, reuse, "fc5")
        return energy
class ResNet32Larger(object):
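    """Largest 32x32 variant, with extra residual blocks (res_2a/b,
    res_5a/b, res_8a/b) interleaved at each resolution."""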
def __init__(self, num_channels=3, num_filters=128):
self.channels = num_channels
self.dim_hidden = num_filters
def construct_weights(self, scope=""):
weights = {}
if FLAGS.cclass:
classes = 10
else:
classes = 1
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, "c1_pre", 3, self.channels, self.dim_hidden)
init_res_weight(
weights,
"res_optim",
3,
self.dim_hidden,
self.dim_hidden,
classes=classes,
)
init_res_weight(
weights, "res_1", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights, "res_2", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights, "res_2a", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights, "res_2b", 3, self.dim_hidden, self.dim_hidden, classes=classes
)
init_res_weight(
weights,
"res_3",
3,
self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_4",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_5",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_5a",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_5b",
3,
2 * self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_6",
3,
2 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_7",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_8",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_8a",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_8b",
3,
4 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_fc_weight(
weights, "fc_dense", 4 * 4 * 2 * self.dim_hidden, 4 * self.dim_hidden
)
init_fc_weight(weights, "fc5", 4 * self.dim_hidden, 1, spec_norm=False)
init_attention_weight(
weights,
"atten",
2 * self.dim_hidden,
                self.dim_hidden // 2,
trainable_gamma=True,
)
return weights
def forward(
self,
inp,
weights,
reuse=False,
scope="",
stop_grad=False,
label=None,
stop_at_grad=False,
stop_batch=False,
):
weights = weights.copy()
if not FLAGS.cclass:
label = None
        if stop_grad:
            weights = stop_gradient_weights(weights)
# Make sure gradients are modified a bit
inp = smart_conv_block(inp, weights, reuse, "c1_pre", use_stride=False)
hidden1 = smart_res_block(
inp, weights, reuse, "res_optim", adaptive=False, label=label
)
hidden2 = smart_res_block(
hidden1,
weights,
reuse,
"res_1",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
)
hidden3 = smart_res_block(
hidden2,
weights,
reuse,
"res_2",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
)
hidden3 = smart_res_block(
hidden3,
weights,
reuse,
"res_2a",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
)
hidden3 = smart_res_block(
hidden3,
weights,
reuse,
"res_2b",
stop_batch=stop_batch,
downsample=False,
adaptive=False,
label=label,
)
hidden4 = smart_res_block(
hidden3, weights, reuse, "res_3", stop_batch=stop_batch, label=label
)
if FLAGS.use_attention:
hidden5 = smart_atten_block(
hidden4, weights, reuse, "atten", stop_at_grad=stop_at_grad
)
else:
hidden5 = smart_res_block(
hidden4,
weights,
reuse,
"res_4",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
hidden6 = smart_res_block(
hidden5,
weights,
reuse,
"res_5",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
hidden6 = smart_res_block(
hidden6,
weights,
reuse,
"res_5a",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
hidden6 = smart_res_block(
hidden6,
weights,
reuse,
"res_5b",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
hidden7 = smart_res_block(
hidden6, weights, reuse, "res_6", stop_batch=stop_batch, label=label
)
hidden8 = smart_res_block(
hidden7,
weights,
reuse,
"res_7",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
hidden9 = smart_res_block(
hidden8,
weights,
reuse,
"res_8",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
hidden9 = smart_res_block(
hidden9,
weights,
reuse,
"res_8a",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
        hidden9 = smart_res_block(
hidden9,
weights,
reuse,
"res_8b",
adaptive=False,
downsample=False,
stop_batch=stop_batch,
label=label,
)
        if FLAGS.cclass:
            out = tf.nn.leaky_relu(hidden9)
        else:
            out = tf.nn.relu(hidden9)
        pooled = tf.reduce_sum(out, [1, 2])
        energy = smart_fc_block(pooled, weights, reuse, "fc5")
        return energy
class ResNet128(object):
"""Construct the convolutional network specified in MAML"""
def __init__(self, num_channels=3, num_filters=64, train=False):
self.channels = num_channels
self.dim_hidden = num_filters
self.dropout = train
self.train = train
def construct_weights(self, scope=""):
weights = {}
classes = 1000
with tf.variable_scope(scope):
# First block
init_conv_weight(weights, "c1_pre", 3, self.channels, 64)
init_res_weight(
weights, "res_optim", 3, 64, self.dim_hidden, classes=classes
)
init_res_weight(
weights,
"res_3",
3,
self.dim_hidden,
2 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_5",
3,
2 * self.dim_hidden,
4 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_7",
3,
4 * self.dim_hidden,
8 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_9",
3,
8 * self.dim_hidden,
8 * self.dim_hidden,
classes=classes,
)
init_res_weight(
weights,
"res_10",
3,
8 * self.dim_hidden,
8 * self.dim_hidden,
classes=classes,
)
init_fc_weight(weights, "fc5", 8 * self.dim_hidden, 1, spec_norm=False)
init_attention_weight(
weights,
"atten",
self.dim_hidden,
                self.dim_hidden // 2,
trainable_gamma=True,
)
return weights
def forward(
self,
inp,
weights,
reuse=False,
scope="",
stop_grad=False,
label=None,
stop_at_grad=False,
stop_batch=False,
):
weights = weights.copy()
if not FLAGS.cclass:
label = None
        if stop_grad:
            weights = stop_gradient_weights(weights)
if FLAGS.swish_act:
act = swish
else:
act = tf.nn.leaky_relu
dropout = self.dropout
train = self.train
# Make sure gradients are modified a bit
inp = smart_conv_block(
inp, weights, reuse, "c1_pre", use_stride=False, activation=act
)
hidden1 = smart_res_block(
inp,
weights,
reuse,
"res_optim",
label=label,
dropout=dropout,
train=train,
downsample=True,
adaptive=False,
)
if FLAGS.use_attention:
hidden1 = smart_atten_block(
hidden1, weights, reuse, "atten", stop_at_grad=stop_at_grad
)
hidden2 = smart_res_block(
hidden1,
weights,
reuse,
"res_3",
stop_batch=stop_batch,
downsample=True,
adaptive=True,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden3 = smart_res_block(
hidden2,
weights,
reuse,
"res_5",
stop_batch=stop_batch,
downsample=True,
adaptive=True,
label=label,
dropout=dropout,
train=train,
act=act,
)
hidden4 = smart_res_block(
hidden3,
weights,
reuse,
"res_7",
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
downsample=True,
adaptive=True,
)
hidden5 = smart_res_block(
hidden4,
weights,
reuse,
"res_9",
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
downsample=True,
adaptive=False,
)
hidden6 = smart_res_block(
hidden5,
weights,
reuse,
"res_10",
stop_batch=stop_batch,
label=label,
dropout=dropout,
train=train,
act=act,
downsample=False,
adaptive=False,
)
        if FLAGS.swish_act:
            out = act(hidden6)
        else:
            out = tf.nn.relu(hidden6)
        pooled = tf.reduce_sum(out, [1, 2])
        energy = smart_fc_block(pooled, weights, reuse, "fc5")
        return energy
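

# A minimal graph-mode smoke test; a sketch, not part of the training code.
# Assumptions: the flags referenced above but defined elsewhere in the
# project (cclass, datasource) are not yet registered, so stand-ins are
# registered here; the helpers in utils do not require further flags; and
# FLAGS is absl-backed, so it must be parsed before first access.
if __name__ == "__main__":
    import sys

    flags.DEFINE_bool("cclass", False, "use a class-conditional model")
    flags.DEFINE_string("datasource", "mnist", "dataset name")
    FLAGS(sys.argv)  # parse flags before the model reads them
    model = MnistNet(num_channels=1, num_filters=64)
    model_weights = model.construct_weights(scope="mnist_ebm")
    x = tf.placeholder(tf.float32, (None, 28, 28, 1))
    energy = model.forward(x, model_weights)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(energy, {x: np.random.rand(4, 28, 28, 1)})
        print("energy shape:", out.shape)  # expected: (4, 1)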