# Copyright (c) 2019 NVIDIA Corporation
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals
import tensorflow as tf
from .loss import Loss
class Text2SpeechLoss(Loss):
    """Default text-to-speech loss.

    The total loss is a weighted sum of:

    * a regression loss between the decoder's spectrogram prediction and the
      target spectrogram,
    * the same regression loss for the post-net prediction,
    * a sigmoid cross-entropy loss on the stop-token logits, and
    * only when the data layer's ``output_type`` contains ``"both"``, a
      regression loss on the predicted magnitude spectrogram.

    The regression loss is mean squared error by default, or mean absolute
    difference when the ``l1_norm`` parameter is set.
    """

    @staticmethod
    def get_optional_params():
        """Return the optional config parameters accepted by this loss.

        Returns:
            dict: parameter name -> expected Python type.

            * ``use_mask`` -- weight the losses with a sequence mask built from
              the target lengths (default ``True``).
            * ``scale`` -- multiplier applied to the final loss (not applied
              when unset).
            * ``stop_token_weight`` -- weight of the stop-token loss
              (default ``1.0``).
            * ``mel_weight`` -- weight of the decoder and post-net spectrogram
              losses (default ``1.0``).
            * ``mag_weight`` -- weight of the magnitude loss in "both" mode
              (default ``1.0``).
            * ``l1_norm`` -- use absolute difference instead of mean squared
              error (default ``False``).
        """
        return {
            "use_mask": bool,
            "scale": float,
            "stop_token_weight": float,
            "mel_weight": float,
            "mag_weight": float,
            "l1_norm": bool,
        }

    def __init__(self, params, model, name="text2speech_loss"):
        """Create the loss and read feature settings from the data layer.

        Args:
            params (dict): loss parameters (see ``get_optional_params``).
            model: the model this loss belongs to; its data layer's
                ``num_audio_features`` and ``output_type`` params are read.
            name (str): name of the loss.
        """
        super(Text2SpeechLoss, self).__init__(params, model, name)
        data_layer_params = self._model.get_data_layer().params
        # In "both" mode this is indexed with "mel" and "magnitude" to split
        # the concatenated target spectrogram.
        self._n_feats = data_layer_params["num_audio_features"]
        # True when targets hold mel and magnitude features side by side.
        self._both = "both" in data_layer_params["output_type"]

    @staticmethod
    def _pad_along_time(tensor, length, value=0.0):
        """Pad a rank-3 ``tensor`` along axis 1 with ``value`` up to ``length``.

        ``length`` must be at least ``tf.shape(tensor)[1]``.
        """
        paddings = [[0, 0], [0, length - tf.shape(tensor)[1]], [0, 0]]
        return tf.pad(tensor, paddings, constant_values=value)

    def _compute_loss(self, input_dict):
        """Compute the loss for a text-to-speech model.

        Args:
            input_dict (dict): with keys

                * "decoder_output": dict containing

                    * "outputs": list where ``outputs[0]`` is the decoder's
                      spectrogram prediction [batch, time, n_mel],
                      ``outputs[1]`` is the spectrogram after adding the
                      post-net residual [batch, time, n_mel] and, in "both"
                      mode, ``outputs[5]`` is the predicted magnitude
                      spectrogram [batch, time, n_mag].
                    * "stop_token_prediction": stop-token logits
                      [batch, time, 1].

                * "target_tensors": list of

                    * spec: target spectrogram [batch, time, feats]; in "both"
                      mode the mel and magnitude features are concatenated
                      along the last axis.
                    * stop_token: stop-token targets [batch, time].
                    * spec_length: target lengths [batch].

        Returns:
            Scalar float32 loss tensor.
        """
        decoder_output = input_dict["decoder_output"]
        outputs = decoder_output["outputs"]
        targets = input_dict["target_tensors"]

        # All loss arithmetic is done in float32 regardless of model dtype.
        decoder_predictions = tf.cast(outputs[0], dtype=tf.float32)
        post_net_predictions = tf.cast(outputs[1], dtype=tf.float32)
        stop_token_predictions = tf.cast(
            decoder_output["stop_token_prediction"], dtype=tf.float32
        )
        spec = tf.cast(targets[0], dtype=tf.float32)
        stop_token = tf.cast(tf.expand_dims(targets[1], -1), dtype=tf.float32)
        spec_lengths = targets[2]

        # Predictions and targets may differ in length; pad every tensor
        # (by its own length) to the longer of the two.
        max_length = tf.maximum(
            tf.shape(spec)[1], tf.shape(decoder_predictions)[1]
        )
        decoder_predictions = self._pad_along_time(
            decoder_predictions, max_length
        )
        post_net_predictions = self._pad_along_time(
            post_net_predictions, max_length
        )
        stop_token_predictions = self._pad_along_time(
            stop_token_predictions, max_length
        )
        spec = self._pad_along_time(spec, max_length)
        # Padded target frames are labelled as "stop" (1).
        stop_token = self._pad_along_time(stop_token, max_length, value=1.0)

        if self.params.get("l1_norm", False):
            loss_f = tf.losses.absolute_difference
        else:
            loss_f = tf.losses.mean_squared_error

        if self._both:
            mag_pred = self._pad_along_time(
                tf.cast(outputs[5], dtype=tf.float32), max_length
            )
            # Split the concatenated target into mel and magnitude parts.
            spec, mag_target = tf.split(
                spec,
                [self._n_feats["mel"], self._n_feats["magnitude"]],
                axis=2,
            )

        use_mask = self.params.get("use_mask", True)
        if use_mask:
            # [batch, time, 1]; broadcasts over the feature axis.
            mask = tf.expand_dims(
                tf.sequence_mask(
                    lengths=spec_lengths, maxlen=max_length, dtype=tf.float32
                ),
                axis=-1,
            )
            weights = mask
        else:
            # Same as the tf.losses default, i.e. an unweighted mean.
            weights = 1.0

        decoder_loss = loss_f(
            labels=spec, predictions=decoder_predictions, weights=weights
        )
        post_net_loss = loss_f(
            labels=spec, predictions=post_net_predictions, weights=weights
        )
        if self._both:
            mag_loss = loss_f(
                labels=mag_target, predictions=mag_pred, weights=weights
            )

        stop_token_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=stop_token, logits=stop_token_predictions
        )
        if use_mask:
            # Average only over valid (unpadded) frames.
            stop_token_loss = (
                tf.reduce_sum(stop_token_loss * mask) / tf.reduce_sum(mask)
            )
        else:
            stop_token_loss = tf.reduce_mean(stop_token_loss)

        mel_weight = self.params.get("mel_weight", 1.0)
        stop_token_weight = self.params.get("stop_token_weight", 1.0)
        loss = (
            mel_weight * decoder_loss
            + mel_weight * post_net_loss
            + stop_token_weight * stop_token_loss
        )
        if self._both:
            loss += self.params.get("mag_weight", 1.0) * mag_loss

        # Compare against None so that an explicit scale of 0.0 is honoured.
        scale = self.params.get("scale", None)
        if scale is not None:
            loss = loss * scale
        return loss