Source code for decoders.centaur_decoder

# Copyright (c) 2019 NVIDIA Corporation
import tensorflow as tf
from tensorflow.python.ops import math_ops

from open_seq2seq.parts.centaur import AttentionBlock
from open_seq2seq.parts.centaur import ConvBlock
from open_seq2seq.parts.centaur import Prenet
from open_seq2seq.parts.transformer import utils
from open_seq2seq.parts.transformer.common import LayerNormalization
from .decoder import Decoder


class CentaurDecoder(Decoder):
  """
  Centaur decoder that consists of attention blocks
  followed by convolutional layers.
  """
  @staticmethod
  def get_required_params():
    return dict(Decoder.get_required_params(), **{
        "prenet_layers": int,
        "prenet_hidden_size": int,
        "hidden_size": int,
        "conv_layers": list,
        "mag_conv_layers": None,
        "attention_dropout": float,
        "layer_postprocess_dropout": float
    })
  @staticmethod
  def get_optional_params():
    return dict(Decoder.get_optional_params(), **{
        "prenet_activation_fn": None,
        "prenet_dropout": float,
        "prenet_use_inference_dropout": bool,
        "cnn_dropout_prob": float,
        "bn_momentum": float,
        "bn_epsilon": float,
        "reduction_factor": int,
        "attention_layers": int,
        "self_attention_conv_params": dict,
        "attention_heads": int,
        "attention_cnn_dropout_prob": float,
        "window_size": int,
        "back_step_size": int,
        "force_layers": list
    })
  def __init__(self, params, model, name="centaur_decoder", mode="train"):
    """
    Centaur decoder constructor.

    See parent class for arguments description.

    Config parameters:

    * **prenet_layers** (int) --- number of fully-connected layers to use.
    * **prenet_hidden_size** (int) --- number of units in each pre-net layer.
    * **hidden_size** (int) --- dimensionality of hidden embeddings.
    * **conv_layers** (list) --- list with the description of convolutional
      layers. For example::

        "conv_layers": [
          {
            "kernel_size": [5], "stride": [1],
            "num_channels": 512, "padding": "VALID", "is_causal": True
          },
          {
            "kernel_size": [5], "stride": [1],
            "num_channels": 512, "padding": "VALID", "is_causal": True
          },
          {
            "kernel_size": [5], "stride": [1],
            "num_channels": 512, "padding": "VALID", "is_causal": True
          },
          {
            "kernel_size": [5], "stride": [1],
            "num_channels": 512, "padding": "VALID", "is_causal": True
          }
        ]

    * **mag_conv_layers** (list) --- list with the description of convolutional
      layers to reconstruct magnitude.
    * **attention_dropout** (float) --- dropout rate for attention layers.
    * **layer_postprocess_dropout** (float) --- dropout rate for transformer
      block sublayers.
    * **prenet_activation_fn** (callable) --- activation function to use for
      the prenet layers. Defaults to relu.
    * **prenet_dropout** (float) --- dropout rate for the pre-net.
      Defaults to 0.5.
    * **prenet_use_inference_dropout** (bool) --- whether to use dropout during
      inference. Defaults to False.
    * **cnn_dropout_prob** (float) --- dropout probability for cnn layers.
      Defaults to 0.5.
    * **bn_momentum** (float) --- momentum for batch norm. Defaults to 0.95.
    * **bn_epsilon** (float) --- epsilon for batch norm. Defaults to 1e-8.
    * **reduction_factor** (int) --- number of frames to predict at a time.
      Defaults to 1.
    * **attention_layers** (int) --- number of attention blocks. Defaults to 4.
    * **self_attention_conv_params** (dict) --- description of the convolutional
      layer inside attention blocks. Defaults to None.
    * **attention_heads** (int) --- number of attention heads. Defaults to 1.
    * **attention_cnn_dropout_prob** (float) --- dropout rate for convolutional
      layers inside attention blocks. Defaults to 0.5.
    * **window_size** (int) --- size of attention window for forcing monotonic
      attention during inference. Defaults to None.
    * **back_step_size** (int) --- number of steps attention is allowed to go
      back during inference. Defaults to 0.
    * **force_layers** (list) --- indices of layers where forcing of monotonic
      attention should be enabled. Defaults to all layers.
    """
    super(CentaurDecoder, self).__init__(params, model, name, mode)

    data_layer_params = model.get_data_layer().params
    n_feats = data_layer_params["num_audio_features"]
    use_mag = "both" in data_layer_params["output_type"]

    self.training = mode == "train"
    self.prenet = None
    self.linear_projection = None
    self.attentions = []
    self.output_normalization = None
    self.conv_layers = []
    self.mag_conv_layers = []
    self.stop_token_projection_layer = None
    self.mel_projection_layer = None
    self.mag_projection_layer = None
    self.n_mel = n_feats["mel"] if use_mag else n_feats
    self.n_mag = n_feats["magnitude"] if use_mag else None
    self.reduction_factor = params.get("reduction_factor", 1)
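For reference, a decoder configuration covering the required keys from the docstring above might look like the sketch below. The concrete values are illustrative assumptions, not taken from any shipped OpenSeq2Seq config.

# Hedged configuration sketch; values are illustrative only.
decoder_params = {
    "prenet_layers": 2,
    "prenet_hidden_size": 256,
    "hidden_size": 256,
    "attention_dropout": 0.1,
    "layer_postprocess_dropout": 0.1,
    "reduction_factor": 2,
    "conv_layers": [
        {"kernel_size": [5], "stride": [1], "num_channels": 512,
         "padding": "VALID", "is_causal": True}
        for _ in range(4)
    ],
    # num_channels == -1 is resolved inside _build_layers to
    # n_mag * reduction_factor for the magnitude branch.
    "mag_conv_layers": [
        {"kernel_size": [5], "stride": [1], "num_channels": -1,
         "padding": "VALID", "is_causal": True}
    ],
}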
  def _build_layers(self):
    regularizer = self._params.get("regularizer", None)
    inference_dropout = self._params.get("prenet_use_inference_dropout", False)

    self.prenet = Prenet(
        n_layers=self._params["prenet_layers"],
        hidden_size=self._params["prenet_hidden_size"],
        activation_fn=self._params.get("prenet_activation_fn", tf.nn.relu),
        dropout=self._params.get("prenet_dropout", 0.5),
        regularizer=regularizer,
        training=self.training or inference_dropout,
        dtype=self._params["dtype"]
    )

    cnn_dropout_prob = self._params.get("cnn_dropout_prob", 0.5)
    bn_momentum = self._params.get("bn_momentum", 0.95)
    bn_epsilon = self._params.get("bn_epsilon", 1e-8)

    self.linear_projection = tf.layers.Dense(
        name="linear_projection",
        units=self._params["hidden_size"],
        use_bias=False,
        kernel_regularizer=regularizer,
        dtype=self._params["dtype"]
    )

    n_layers = self._params.get("attention_layers", 4)
    n_heads = self._params.get("attention_heads", 1)
    conv_params = self._params.get("self_attention_conv_params", None)
    force_layers = self._params.get("force_layers", range(n_layers))

    for index in range(n_layers):
      window_size = None

      if index in force_layers:
        window_size = self._params.get("window_size", None)

      attention = AttentionBlock(
          name="attention_block_%d" % index,
          hidden_size=self._params["hidden_size"],
          attention_dropout=self._params["attention_dropout"],
          layer_postprocess_dropout=self._params["layer_postprocess_dropout"],
          regularizer=regularizer,
          training=self.training,
          cnn_dropout_prob=self._params.get("attention_cnn_dropout_prob", 0.5),
          conv_params=conv_params,
          n_heads=n_heads,
          window_size=window_size,
          back_step_size=self._params.get("back_step_size", None)
      )
      self.attentions.append(attention)

    self.output_normalization = LayerNormalization(self._params["hidden_size"])

    for index, params in enumerate(self._params["conv_layers"]):
      if params["num_channels"] == -1:
        params["num_channels"] = self.n_mel * self.reduction_factor

      layer = ConvBlock.create(
          index=index,
          conv_params=params,
          regularizer=regularizer,
          bn_momentum=bn_momentum,
          bn_epsilon=bn_epsilon,
          cnn_dropout_prob=cnn_dropout_prob,
          training=self.training
      )
      self.conv_layers.append(layer)

    for index, params in enumerate(self._params["mag_conv_layers"]):
      if params["num_channels"] == -1:
        params["num_channels"] = self.n_mag * self.reduction_factor

      layer = ConvBlock.create(
          index=index,
          conv_params=params,
          regularizer=regularizer,
          bn_momentum=bn_momentum,
          bn_epsilon=bn_epsilon,
          cnn_dropout_prob=cnn_dropout_prob,
          training=self.training
      )
      self.mag_conv_layers.append(layer)

    self.stop_token_projection_layer = tf.layers.Dense(
        name="stop_token_projection",
        units=1 * self.reduction_factor,
        use_bias=True,
        kernel_regularizer=regularizer
    )

    self.mel_projection_layer = tf.layers.Dense(
        name="mel_projection",
        units=self.n_mel * self.reduction_factor,
        use_bias=True,
        kernel_regularizer=regularizer
    )

    self.mag_projection_layer = tf.layers.Dense(
        name="mag_projection",
        units=self.n_mag * self.reduction_factor,
        use_bias=True,
        kernel_regularizer=regularizer
    )

  def _decode(self, input_dict):
    self._build_layers()

    if "target_tensors" in input_dict:
      targets = input_dict["target_tensors"][0]
    else:
      targets = None

    encoder_outputs = input_dict["encoder_output"]["outputs"]
    attention_bias = input_dict["encoder_output"]["inputs_attention_bias"]

    spec_length = None
    if self.mode == "train" or self.mode == "eval":
      if "target_tensors" in input_dict:
        spec_length = input_dict["target_tensors"][2]

    if self.training:
      return self._train(targets, encoder_outputs, attention_bias, spec_length)

    return self._infer(encoder_outputs, attention_bias, spec_length)

  def _decode_pass(self,
                   decoder_inputs,
                   encoder_outputs,
                   enc_dec_attention_bias,
                   sequence_lengths=None,
                   alignment_positions=None):
    y = self.prenet(decoder_inputs)
    y = self.linear_projection(y)

    with tf.variable_scope("decoder_pos_encoding"):
      pos_encoding = self._positional_encoding(y, self.params["dtype"])
      y += pos_encoding

    with tf.variable_scope("encoder_pos_encoding"):
      pos_encoding = self._positional_encoding(encoder_outputs, self.params["dtype"])
      encoder_outputs += pos_encoding

    for i, attention in enumerate(self.attentions):
      positions = None
      if alignment_positions is not None:
        positions = alignment_positions[i, :, :, :]

      y = attention(y, encoder_outputs, enc_dec_attention_bias, positions=positions)

    y = self.output_normalization(y)

    with tf.variable_scope("conv_layers"):
      for layer in self.conv_layers:
        y = layer(y)

    stop_token_logits = self.stop_token_projection_layer(y)
    mel_spec = self.mel_projection_layer(y)

    with tf.variable_scope("mag_conv_layers"):
      for layer in self.mag_conv_layers:
        y = layer(y)

    mag_spec = self.mag_projection_layer(y)

    if sequence_lengths is None:
      batch_size = tf.shape(y)[0]
      sequence_lengths = tf.zeros([batch_size])

    return {
        "spec": mel_spec,
        "post_net_spec": mel_spec,
        "alignments": None,
        "stop_token_logits": stop_token_logits,
        "lengths": sequence_lengths,
        "mag_spec": mag_spec
    }

  def _train(self, targets, encoder_outputs, enc_dec_attention_bias, sequence_lengths):
    # Shift targets to the right, and remove the last element
    with tf.name_scope("shift_targets"):
      n_features = self.n_mel + self.n_mag
      targets = targets[:, :, :n_features]
      targets = self._shrink(targets, n_features, self.reduction_factor)
      decoder_inputs = tf.pad(targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]

    outputs = self._decode_pass(
        decoder_inputs=decoder_inputs,
        encoder_outputs=encoder_outputs,
        enc_dec_attention_bias=enc_dec_attention_bias,
        sequence_lengths=sequence_lengths
    )

    with tf.variable_scope("alignments"):
      weights = []
      for index in range(len(self.attentions)):
        op = "ForwardPass/centaur_decoder/attention_block_%d/attention/attention/attention_weights" % index
        weights_operation = tf.get_default_graph().get_operation_by_name(op)
        weight = weights_operation.values()[0]
        weights.append(weight)

      outputs["alignments"] = [tf.stack(weights)]

    return self._convert_outputs(
        outputs,
        self.reduction_factor,
        self._model.params["batch_size_per_gpu"]
    )

  def _infer(self, encoder_outputs, enc_dec_attention_bias, sequence_lengths):
    if sequence_lengths is None:
      maximum_iterations = self._model.get_data_layer()._params.get("duration_max", 1000)
    else:
      maximum_iterations = tf.reduce_max(sequence_lengths)

    maximum_iterations //= self.reduction_factor

    state, state_shape_invariants = self._inference_initial_state(
        encoder_outputs,
        enc_dec_attention_bias
    )

    state = tf.while_loop(
        cond=self._inference_cond,
        body=self._inference_step,
        loop_vars=[state],
        shape_invariants=state_shape_invariants,
        back_prop=False,
        maximum_iterations=maximum_iterations,
        parallel_iterations=1
    )

    return self._convert_outputs(
        state["outputs"],
        self.reduction_factor,
        self._model.params["batch_size_per_gpu"]
    )
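The "shift targets" step in _train implements standard teacher forcing: a zero "go" frame is prepended and the last frame dropped, so the decoder predicts frame t from ground-truth frames before t. A minimal runnable sketch with made-up values:

import tensorflow as tf

targets = tf.constant([[[1.0], [2.0], [3.0]]])  # [batch=1, time=3, feats=1]
decoder_inputs = tf.pad(targets, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
# decoder_inputs == [[[0.], [1.], [2.]]]: a zero go-frame followed by the
# targets shifted right by one step.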
  def _inference_initial_state(self, encoder_outputs, encoder_decoder_attention_bias):
    """Create initial state for inference."""

    with tf.variable_scope("inference_initial_state"):
      batch_size = tf.shape(encoder_outputs)[0]
      n_layers = self._params.get("attention_layers", 1)
      n_heads = self._params.get("attention_heads", 1)
      n_features = self.n_mel + self.n_mag

      state = {
          "iteration": tf.constant(0),
          "inputs": tf.zeros([batch_size, 1, n_features * self.reduction_factor]),
          "finished": tf.cast(tf.zeros([batch_size]), tf.bool),
          "alignment_positions": tf.zeros([n_layers, batch_size, n_heads, 1],
                                          dtype=tf.int32),
          "outputs": {
              "spec": tf.zeros([batch_size, 0, self.n_mel * self.reduction_factor]),
              "post_net_spec": tf.zeros([batch_size, 0, self.n_mel * self.reduction_factor]),
              "alignments": [
                  tf.zeros([0, 0, 0, 0, 0])
              ],
              "stop_token_logits": tf.zeros([batch_size, 0, 1 * self.reduction_factor]),
              "lengths": tf.zeros([batch_size], dtype=tf.int32),
              "mag_spec": tf.zeros([batch_size, 0, self.n_mag * self.reduction_factor])
          },
          "encoder_outputs": encoder_outputs,
          "encoder_decoder_attention_bias": encoder_decoder_attention_bias
      }

      state_shape_invariants = {
          "iteration": tf.TensorShape([]),
          "inputs": tf.TensorShape([None, None, n_features * self.reduction_factor]),
          "finished": tf.TensorShape([None]),
          "alignment_positions": tf.TensorShape([n_layers, None, n_heads, None]),
          "outputs": {
              "spec": tf.TensorShape([None, None, self.n_mel * self.reduction_factor]),
              "post_net_spec": tf.TensorShape([None, None, self.n_mel * self.reduction_factor]),
              "alignments": [
                  tf.TensorShape([None, None, None, None, None]),
              ],
              "stop_token_logits": tf.TensorShape([None, None, 1 * self.reduction_factor]),
              "lengths": tf.TensorShape([None]),
              "mag_spec": tf.TensorShape([None, None, None])
          },
          "encoder_outputs": encoder_outputs.shape,
          "encoder_decoder_attention_bias": encoder_decoder_attention_bias.shape
      }

      return state, state_shape_invariants
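The shape invariants exist because the output tensors grow along the time axis on every loop iteration; tf.while_loop requires any growing dimension to be declared as None. A self-contained sketch of the same pattern with illustrative shapes:

import tensorflow as tf

def body(i, acc):
    frame = tf.zeros([2, 1, 4])                       # one new frame per step
    return [i + 1, tf.concat([acc, frame], axis=1)]   # grow along time

_, result = tf.while_loop(
    cond=lambda i, acc: i < 3,
    body=body,
    loop_vars=[tf.constant(0), tf.zeros([2, 0, 4])],
    shape_invariants=[tf.TensorShape([]), tf.TensorShape([2, None, 4])]
)
# result has shape [2, 3, 4]; without None in the invariant, the concat
# inside the loop would fail the shape check.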
  def _inference_cond(self, state):
    """Check if it's time to stop inference."""

    with tf.variable_scope("inference_cond"):
      all_finished = math_ops.reduce_all(state["finished"])
      return tf.logical_not(all_finished)
  def _inference_step(self, state):
    """Make one inference step."""

    decoder_inputs = state["inputs"]
    encoder_outputs = state["encoder_outputs"]
    enc_dec_attention_bias = state["encoder_decoder_attention_bias"]
    alignment_positions = state["alignment_positions"]

    outputs = self._decode_pass(
        decoder_inputs=decoder_inputs,
        encoder_outputs=encoder_outputs,
        enc_dec_attention_bias=enc_dec_attention_bias,
        alignment_positions=alignment_positions
    )

    with tf.variable_scope("inference_step"):
      next_inputs_mel = outputs["post_net_spec"][:, -1:, :]
      next_inputs_mel = self._expand(next_inputs_mel, self.reduction_factor)
      next_inputs_mag = outputs["mag_spec"][:, -1:, :]
      next_inputs_mag = self._expand(next_inputs_mag, self.reduction_factor)
      next_inputs = tf.concat([next_inputs_mel, next_inputs_mag], axis=-1)

      n_features = self.n_mel + self.n_mag
      next_inputs = self._shrink(next_inputs, n_features, self.reduction_factor)

      # Set zero if sequence is finished
      next_inputs = tf.where(
          state["finished"],
          tf.zeros_like(next_inputs),
          next_inputs
      )
      next_inputs = tf.concat([decoder_inputs, next_inputs], 1)

      # Update lengths
      lengths = state["outputs"]["lengths"]
      lengths = tf.where(
          state["finished"],
          lengths,
          lengths + 1 * self.reduction_factor
      )
      outputs["lengths"] = lengths

      # Update spec, post_net_spec and mag_spec
      for key in ["spec", "post_net_spec", "mag_spec"]:
        output = outputs[key][:, -1:, :]
        output = tf.where(state["finished"], tf.zeros_like(output), output)
        outputs[key] = tf.concat([state["outputs"][key], output], 1)

      # Update stop token logits; finished sequences get their logits pinned
      # to +1e9 so they keep predicting "stop" on subsequent iterations
      stop_token_logits = outputs["stop_token_logits"][:, -1:, :]
      stop_token_logits = tf.where(
          state["finished"],
          tf.zeros_like(stop_token_logits) + 1e9,
          stop_token_logits
      )
      stop_prediction = tf.sigmoid(stop_token_logits)
      stop_prediction = tf.reduce_max(stop_prediction, axis=-1)

      # A sequence is finished once its rounded stop prediction fires
      finished = tf.reshape(tf.cast(tf.round(stop_prediction), tf.bool), [-1])

      stop_token_logits = tf.concat(
          [state["outputs"]["stop_token_logits"], stop_token_logits],
          axis=1
      )
      outputs["stop_token_logits"] = stop_token_logits

      with tf.variable_scope("alignments"):
        forward = "ForwardPass" if self.mode == "infer" else "ForwardPass_1"

        weights = []
        for index in range(len(self.attentions)):
          op = forward + "/centaur_decoder/while/attention_block_%d/attention/attention/attention_weights" % index
          weights_operation = tf.get_default_graph().get_operation_by_name(op)
          weight = weights_operation.values()[0]
          weights.append(weight)

        weights = tf.stack(weights)
        outputs["alignments"] = [weights]

      alignment_positions = tf.argmax(
          weights,
          axis=-1,
          output_type=tf.int32
      )[:, :, :, -1:]
      state["alignment_positions"] = tf.concat(
          [state["alignment_positions"], alignment_positions],
          axis=-1
      )

      state["iteration"] = state["iteration"] + 1
      state["inputs"] = next_inputs
      state["finished"] = finished
      state["outputs"] = outputs

    return state
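The stop decision above rounds the sigmoid of the last-frame stop logit: a probability above 0.5 marks the sequence finished. A minimal numeric sketch with made-up logit values:

import tensorflow as tf

logits = tf.constant([[-2.0], [0.7]])  # stop logits for two sequences
finished = tf.reshape(tf.cast(tf.round(tf.sigmoid(logits)), tf.bool), [-1])
# sigmoid(-2.0) ~= 0.12 rounds to 0 (keep decoding); sigmoid(0.7) ~= 0.67
# rounds to 1 (stop), so finished evaluates to [False, True].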
  @staticmethod
  def _shrink(values, last_dim, reduction_factor):
    """Shrink the given input by reduction_factor."""

    shape = tf.shape(values)
    new_shape = [
        shape[0],
        shape[1] // reduction_factor,
        last_dim * reduction_factor
    ]
    values = tf.reshape(values, new_shape)
    return values
  @staticmethod
  def _expand(values, reduction_factor):
    """Expand the given input by reduction_factor."""

    shape = tf.shape(values)
    new_shape = [
        shape[0],
        shape[1] * reduction_factor,
        shape[2] // reduction_factor
    ]
    values = tf.reshape(values, new_shape)
    return values
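_shrink and _expand are inverse reshapes: with reduction_factor frames packed along the feature axis, each decoder step emits several frames at once. A worked sketch with illustrative shapes:

import tensorflow as tf

values = tf.zeros([4, 100, 80])                 # [batch, time, feats]
packed = CentaurDecoder._shrink(values, 80, 2)  # -> [4, 50, 160]
restored = CentaurDecoder._expand(packed, 2)    # -> [4, 100, 80]
# Two consecutive frames are folded into one step of width 160, halving the
# number of autoregressive steps the decoder must take.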
  @staticmethod
  def _positional_encoding(x, dtype):
    """Add positional encoding to the given input."""

    length = tf.shape(x)[1]
    features_count = tf.shape(x)[2]
    # The encoding is computed for an even number of features,
    # then sliced back to the original feature count.
    padded_features_count = features_count + features_count % 2
    pos_encoding = utils.get_position_encoding(length, padded_features_count)
    position_encoding = tf.cast(pos_encoding, dtype)
    position_encoding = position_encoding[:, :features_count]
    return position_encoding
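The even-feature requirement comes from the standard Transformer sinusoidal scheme, which splits the channels into a sine half and a cosine half. A hedged numpy restatement of that scheme (the actual implementation lives in open_seq2seq.parts.transformer.utils.get_position_encoding, and may differ in detail):

import numpy as np

def sinusoidal_encoding(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    # Positions scaled by geometrically spaced inverse timescales,
    # then sin/cos halves concatenated along the channel axis.
    position = np.arange(length, dtype=np.float64)[:, None]
    num_timescales = channels // 2
    log_inc = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(np.arange(num_timescales) * -log_inc)
    scaled = position * inv_timescales[None, :]
    return np.concatenate([np.sin(scaled), np.cos(scaled)], axis=1)

# sinusoidal_encoding(50, 8).shape == (50, 8); an odd channel count would
# lose one channel to the integer halving, hence the rounding-up above.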
  @staticmethod
  def _convert_outputs(outputs, reduction_factor, batch_size):
    """Convert output of the decoder to appropriate format."""

    with tf.variable_scope("output_converter"):
      for key in ["spec", "post_net_spec", "stop_token_logits", "mag_spec"]:
        outputs[key] = CentaurDecoder._expand(outputs[key], reduction_factor)

      alignments = []
      for sample in range(batch_size):
        alignments.append([outputs["alignments"][0][:, sample, :, :, :]])

      return {
          "outputs": [
              outputs["spec"],
              outputs["post_net_spec"],
              alignments,
              tf.sigmoid(outputs["stop_token_logits"]),
              outputs["lengths"],
              outputs["mag_spec"]
          ],
          "stop_token_prediction": outputs["stop_token_logits"]
      }