Source code for parts.centaur.conv_block

# Copyright (c) 2019 NVIDIA Corporation
import tensorflow as tf

from .batch_norm import BatchNorm1D


class ConvBlock:
  """
  Convolutional block for Centaur model.

  Applies, in order: optional causal left-padding, the convolution, an
  optional normalization layer, an optional activation, an optional dropout
  layer and an optional residual connection.
  """

  def __init__(self,
               name,
               conv,
               norm,
               activation_fn,
               dropout,
               training,
               is_residual,
               is_causal):
    """
    Convolutional block constructor.

    Args:
      name: name of the block; used as the variable scope in ``__call__``.
      conv: convolutional layer. Must be callable on a tensor; when
        ``is_causal`` is set it must also expose ``kernel_size`` (and may
        expose ``dilation_rate``).
      norm: normalization layer to use after the convolutional layer,
        or None to skip normalization.
      activation_fn: activation function to use after the normalization,
        or None to skip it.
      dropout: dropout layer applied after the activation (called as
        ``dropout(y, training=...)``), or None to skip dropout.
      training: whether it is training mode.
      is_residual: whether the block should contain a residual connection.
      is_causal: whether the convolutional layer should be causal.
    """
    self.name = name
    self.conv = conv
    self.norm = norm
    self.activation_fn = activation_fn
    self.dropout = dropout
    self.training = training
    self.is_residual = is_residual
    self.is_causal = is_causal

  @property
  def is_casual(self):
    """Backward-compatible alias for the historically misspelled attribute."""
    return self.is_causal

  @is_casual.setter
  def is_casual(self, value):
    self.is_causal = value

  def _causal_pad_size(self):
    """Returns the number of left-padding steps that keeps the conv causal."""
    kernel_size = self.conv.kernel_size[0]
    # A dilated kernel spans (kernel_size - 1) * dilation past steps.
    # Layers without the attribute are treated as non-dilated.
    dilation = getattr(self.conv, "dilation_rate", 1)
    if isinstance(dilation, (tuple, list)):
      dilation = dilation[0]
    return (kernel_size - 1) * dilation

  def __call__(self, x):
    """
    Applies the block to ``x`` inside the ``self.name`` variable scope.

    Args:
      x: input tensor. The causal branch pads the second axis of a rank-3
        tensor, so x is presumably (batch, time, channels) - TODO confirm.

    Returns:
      ``x + y`` if the block is residual, otherwise ``y``, where ``y`` is the
      output of conv -> norm -> activation -> dropout.
    """
    with tf.variable_scope(self.name):
      if self.is_causal:
        # Add padding from the left side to avoid looking to the future
        pad_size = self._causal_pad_size()
        y = tf.pad(x, [[0, 0], [pad_size, 0], [0, 0]])
      else:
        y = x

      y = self.conv(y)

      if self.norm is not None:
        y = self.norm(y, training=self.training)

      if self.activation_fn is not None:
        y = self.activation_fn(y)

      if self.dropout is not None:
        y = self.dropout(y, training=self.training)

      return x + y if self.is_residual else y

  @staticmethod
  def create(index,
             conv_params,
             regularizer,
             bn_momentum,
             bn_epsilon,
             cnn_dropout_prob,
             training,
             is_residual=True,
             is_causal=False):
    """
    Builds a ConvBlock with Conv1D, BatchNorm1D and Dropout layers.

    Args:
      index: integer index of the block; used in the layer and scope names.
      conv_params: dict describing the convolution. Required keys:
        "num_channels", "kernel_size", "stride", "padding". Optional keys:
        "activation_fn" (defaults to ``tf.nn.relu``), "is_causal" and
        "is_residual" (each overrides the same-named argument when present).
      regularizer: regularizer for the conv kernel and the batch norm gamma.
      bn_momentum: batch norm momentum.
      bn_epsilon: batch norm epsilon.
      cnn_dropout_prob: dropout rate.
      training: whether it is training mode.
      is_residual: default for the residual connection flag.
      is_causal: default for the causal convolution flag.

    Returns:
      A ConvBlock named "layer_<index>".
    """
    activation_fn = conv_params.get("activation_fn", tf.nn.relu)

    conv = tf.layers.Conv1D(
        name="conv_%d" % index,
        filters=conv_params["num_channels"],
        kernel_size=conv_params["kernel_size"],
        strides=conv_params["stride"],
        padding=conv_params["padding"],
        kernel_regularizer=regularizer
    )

    norm = BatchNorm1D(
        name="bn_%d" % index,
        gamma_regularizer=regularizer,
        momentum=bn_momentum,
        epsilon=bn_epsilon
    )

    dropout = tf.layers.Dropout(
        name="dropout_%d" % index,
        rate=cnn_dropout_prob
    )

    # Per-layer settings in conv_params take precedence over the arguments.
    is_causal = conv_params.get("is_causal", is_causal)
    is_residual = conv_params.get("is_residual", is_residual)

    return ConvBlock(
        name="layer_%d" % index,
        conv=conv,
        norm=norm,
        activation_fn=activation_fn,
        dropout=dropout,
        training=training,
        is_residual=is_residual,
        is_causal=is_causal
    )