# Copyright (c) 2018 NVIDIA Corporation
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from six.moves import range
import tensorflow as tf
from .tcn import tcn

layers_dict = {
    "conv1d": tf.layers.conv1d,
    "conv2d": tf.layers.conv2d,
    "tcn": tcn,
}


def conv_actv(layer_type, name, inputs, filters, kernel_size, activation_fn,
              strides, padding, regularizer, training, data_format, dilation=1):
  """Helper function that applies convolution and activation.

  Args:
    layer_type: the following types are supported
      'conv1d', 'conv2d'
  """
  layer = layers_dict[layer_type]
  conv = layer(
      name="{}".format(name),
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      dilation_rate=dilation,
      kernel_regularizer=regularizer,
      use_bias=False,
      data_format=data_format,
  )
  output = conv
  if activation_fn is not None:
    output = activation_fn(output)
  return output
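
# Example usage (a minimal sketch; the input tensor and hyperparameters below
# are illustrative assumptions, not values defined in this module):
#
#   inputs = tf.placeholder(tf.float32, [None, 200, 64])  # [batch, time, ch]
#   out = conv_actv(
#       layer_type="conv1d", name="conv_actv_example", inputs=inputs,
#       filters=128, kernel_size=11, activation_fn=tf.nn.relu, strides=1,
#       padding="SAME", regularizer=None,
#       training=True,  # accepted for interface parity; unused in conv_actv
#       data_format="channels_last",
#   )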


def conv_bn_res_bn_actv(layer_type, name, inputs, res_inputs, filters,
                        kernel_size, activation_fn, strides, padding,
                        regularizer, training, data_format, bn_momentum,
                        bn_epsilon, dilation=1,
                        drop_block_prob=0.0, drop_block=False):
  """Helper function that applies convolution, batch norm and activation,
  summing in residual connections that are each first passed through their
  own 1x1 convolution and batch norm.

  Args:
    layer_type: the following types are supported
      'conv1d', 'conv2d'
  """
  layer = layers_dict[layer_type]

  if not isinstance(res_inputs, list):
    res_inputs = [res_inputs]
    # For backwards compatibility with earlier models
    res_name = "{}/res"
    res_bn_name = "{}/res_bn"
  else:
    res_name = "{}/res_{}"
    res_bn_name = "{}/res_bn_{}"
  res_aggregation = 0
  for i, res in enumerate(res_inputs):
    # Project each residual input to `filters` channels with a 1x1 convolution
    # and batch-normalize it before summing.
    res = layer(
        res,
        filters,
        1,
        name=res_name.format(name, i),
        use_bias=False,
    )
    squeeze = False
    if layer_type == "conv1d":
      axis = 1 if data_format == 'channels_last' else 2
      res = tf.expand_dims(res, axis=axis)  # NWC --> NHWC
      squeeze = True
    res = tf.layers.batch_normalization(
        name=res_bn_name.format(name, i),
        inputs=res,
        gamma_regularizer=regularizer,
        training=training,
        axis=-1 if data_format == 'channels_last' else 1,
        momentum=bn_momentum,
        epsilon=bn_epsilon,
    )
    if squeeze:
      res = tf.squeeze(res, axis=axis)
    res_aggregation += res

  conv = layer(
      name="{}".format(name),
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      dilation_rate=dilation,
      kernel_regularizer=regularizer,
      use_bias=False,
      data_format=data_format,
  )
  # Trick to make batch norm work for mixed precision training: the fused
  # batch-norm kernel expects a 4D tensor, so conv1d outputs are temporarily
  # expanded to 4D and squeezed back afterwards.
  # TODO: check if batch norm works smoothly for >4-dimensional tensors
  squeeze = False
  if layer_type == "conv1d":
    axis = 1 if data_format == 'channels_last' else 2
    conv = tf.expand_dims(conv, axis=axis)  # NWC --> NHWC
    squeeze = True
  bn = tf.layers.batch_normalization(
      name="{}/bn".format(name),
      inputs=conv,
      gamma_regularizer=regularizer,
      training=training,
      axis=-1 if data_format == 'channels_last' else 1,
      momentum=bn_momentum,
      epsilon=bn_epsilon,
  )
  if squeeze:
    bn = tf.squeeze(bn, axis=axis)

  output = bn + res_aggregation
  if drop_block_prob > 0:
    if training:
      # With probability `drop_block_prob`, drop the conv branch entirely and
      # keep only the aggregated residuals.
      output = tf.cond(
          tf.random_uniform(shape=[]) < drop_block_prob,
          lambda: res_aggregation,
          lambda: bn + res_aggregation,
      )
    elif drop_block:
      # At inference time, drop the conv branch permanently if requested.
      output = res_aggregation
  if activation_fn is not None:
    output = activation_fn(output)
  return output
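
# Example usage (a minimal sketch; the tensors and hyperparameters below are
# illustrative assumptions): sum two earlier block outputs into this block as
# residual connections.
#
#   out = conv_bn_res_bn_actv(
#       layer_type="conv1d", name="res_block_example", inputs=block2_out,
#       res_inputs=[block0_out, block1_out],
#       filters=256, kernel_size=11, activation_fn=tf.nn.relu, strides=1,
#       padding="SAME", regularizer=None, training=True,
#       data_format="channels_last", bn_momentum=0.9, bn_epsilon=1e-3,
#       drop_block_prob=0.1,  # randomly drop the conv branch while training
#   )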


def conv_bn_actv(layer_type, name, inputs, filters, kernel_size, activation_fn,
                 strides, padding, regularizer, training, data_format,
                 bn_momentum, bn_epsilon, dilation=1):
  """Helper function that applies convolution, batch norm and activation.

  Args:
    layer_type: the following types are supported
      'conv1d', 'conv2d'
  """
  layer = layers_dict[layer_type]
  conv = layer(
      name="{}".format(name),
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      dilation_rate=dilation,
      kernel_regularizer=regularizer,
      use_bias=False,
      data_format=data_format,
  )
  # Trick to make batch norm work for mixed precision training: the fused
  # batch-norm kernel expects a 4D tensor, so conv1d outputs are temporarily
  # expanded to 4D and squeezed back afterwards.
  # TODO: check if batch norm works smoothly for >4-dimensional tensors
  squeeze = False
  if layer_type == "conv1d":
    axis = 1 if data_format == 'channels_last' else 2
    conv = tf.expand_dims(conv, axis=axis)  # NWC --> NHWC
    squeeze = True
  bn = tf.layers.batch_normalization(
      name="{}/bn".format(name),
      inputs=conv,
      gamma_regularizer=regularizer,
      training=training,
      axis=-1 if data_format == 'channels_last' else 1,
      momentum=bn_momentum,
      epsilon=bn_epsilon,
  )
  if squeeze:
    bn = tf.squeeze(bn, axis=axis)

  output = bn
  if activation_fn is not None:
    output = activation_fn(output)
  return output
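
# Shape walk-through of the batch-norm trick above, for a conv1d layer with
# data_format='channels_last' (illustrative shapes):
#
#   conv: [N, W, C]
#   conv = tf.expand_dims(conv, axis=1)            # -> [N, 1, W, C] (NHWC)
#   bn = tf.layers.batch_normalization(conv, ...)  # normalizes over axis=-1
#   bn = tf.squeeze(bn, axis=1)                    # -> [N, W, C]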


def conv_ln_actv(layer_type, name, inputs, filters, kernel_size, activation_fn,
                 strides, padding, regularizer, training, data_format,
                 dilation=1):
  """Helper function that applies convolution, layer norm and activation.

  Args:
    layer_type: the following types are supported
      'conv1d', 'conv2d'
  """
  layer = layers_dict[layer_type]
  conv = layer(
      name="{}".format(name),
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      dilation_rate=dilation,
      kernel_regularizer=regularizer,
      use_bias=False,
      data_format=data_format,
  )
  if data_format == 'channels_first':
    if layer_type == "conv1d":
      conv = tf.transpose(conv, [0, 2, 1])
    elif layer_type == "conv2d":
      conv = tf.transpose(conv, [0, 2, 3, 1])
  ln = tf.contrib.layers.layer_norm(
      inputs=conv,
  )
  if data_format == 'channels_first':
    if layer_type == "conv1d":
      ln = tf.transpose(ln, [0, 2, 1])
    elif layer_type == "conv2d":
      ln = tf.transpose(ln, [0, 3, 1, 2])

  output = ln
  if activation_fn is not None:
    output = activation_fn(output)
  return output
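
# Shape walk-through for the channels_first transposes above, conv2d case
# (illustrative): tf.contrib.layers.layer_norm learns its scale/offset over
# the last axis, so channels are moved there and back:
#
#   conv: [N, C, H, W] --transpose([0, 2, 3, 1])--> [N, H, W, C]
#   ln = tf.contrib.layers.layer_norm(conv)
#   ln:   [N, H, W, C] --transpose([0, 3, 1, 2])--> [N, C, H, W]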


def conv_in_actv(layer_type, name, inputs, filters, kernel_size, activation_fn,
                 strides, padding, regularizer, training, data_format,
                 dilation=1):
  """Helper function that applies convolution, instance norm and activation.

  Args:
    layer_type: the following types are supported
      'conv1d', 'conv2d'
  """
  layer = layers_dict[layer_type]
  conv = layer(
      name="{}".format(name),
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      dilation_rate=dilation,
      kernel_regularizer=regularizer,
      use_bias=False,
      data_format=data_format,
  )
  sn = tf.contrib.layers.instance_norm(
      inputs=conv,
      data_format="NHWC" if data_format == 'channels_last' else "NCHW"
  )

  output = sn
  if activation_fn is not None:
    output = activation_fn(output)
  return output
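
# Example of chaining these helpers into a small block stack (a minimal
# sketch; the shapes and hyperparameters are illustrative assumptions):
#
#   x = inputs  # e.g. [batch, time, features] for conv1d / channels_last
#   for i in range(3):
#     x = conv_bn_actv(
#         layer_type="conv1d", name="block_{}".format(i), inputs=x,
#         filters=256, kernel_size=11, activation_fn=tf.nn.relu, strides=1,
#         padding="SAME", regularizer=None, training=True,
#         data_format="channels_last", bn_momentum=0.9, bn_epsilon=1e-3,
#     )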