# Source code for losses.ctc_loss

# Copyright (c) 2018 NVIDIA Corporation

from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import tensorflow as tf

from open_seq2seq.utils.utils import mask_nans, deco_print
from .loss import Loss


def dense_to_sparse(dense_tensor, sequence_length):
  """Convert a dense tensor to a ``tf.SparseTensor``, keeping only the
  entries that fall within each row's valid length.

  Positions beyond ``sequence_length[i]`` in row ``i`` are dropped, which is
  the form ``tf.nn.ctc_loss`` expects for its labels.

  Args:
    dense_tensor: dense tensor; the first two dimensions are interpreted as
        [batch, time] by ``tf.sequence_mask``.
    sequence_length: 1-D int tensor of shape [batch] with the valid length
        of each row.

  Returns:
    tf.SparseTensor containing only the in-length entries of
    ``dense_tensor``, with the same dense shape.
  """
  # indices of all positions where the mask is True, shape [nnz, rank]
  indices = tf.where(tf.sequence_mask(sequence_length))
  values = tf.gather_nd(dense_tensor, indices)
  # SparseTensor requires an int64 dense shape
  shape = tf.shape(dense_tensor, out_type=tf.int64)
  return tf.SparseTensor(indices, values, shape)
class CTCLoss(Loss):
  """Implementation of the CTC loss."""

  @staticmethod
  def get_optional_params():
    return dict(Loss.get_optional_params(), **{
        'mask_nan': bool,
    })

  def __init__(self, params, model, name="ctc_loss"):
    """CTC loss constructor.

    See parent class for arguments description.

    Config parameters:

    * **mask_nan** (bool) --- whether to mask nans in the loss output.
      Defaults to True.
    """
    super(CTCLoss, self).__init__(params, model, name)
    self._mask_nan = self.params.get("mask_nan", True)
    # this loss can only operate in full precision
    # if self.params['dtype'] != tf.float32:
    #   deco_print("Warning: defaulting CTC loss to work in float32")
    self.params['dtype'] = tf.float32

  def _compute_loss(self, input_dict):
    """CTC loss graph construction.

    Args:
      input_dict (dict): input dictionary that has to contain the following
          fields::

            input_dict = {
              "decoder_output": {
                "logits": tensor, shape [batch_size, time length, tgt_vocab_size]
                "src_length": tensor, shape [batch_size]
              },
              "target_tensors": [
                tgt_sequence (shape=[batch_size, time length, num features]),
                tgt_length (shape=[batch_size])
              ]
            }

    Returns:
      averaged CTC loss.
    """
    logits = input_dict['decoder_output']['logits']
    tgt_sequence, tgt_length = input_dict['target_tensors']
    # this loss needs an access to src_length since they
    # might get changed in the encoder
    src_length = input_dict['decoder_output']['src_length']

    # Compute the CTC loss
    total_loss = tf.nn.ctc_loss(
        labels=dense_to_sparse(tgt_sequence, tgt_length),
        inputs=logits,
        sequence_length=src_length,
        ignore_longer_outputs_than_inputs=True,
    )

    if self._mask_nan:
      total_loss = mask_nans(total_loss)

    # Calculate the average loss across the batch
    avg_loss = tf.reduce_mean(total_loss)
    return avg_loss