Source code for losses.cross_entropy_loss

# Copyright (c) 2018 NVIDIA Corporation

from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import tensorflow as tf

from .loss import Loss


class CrossEntropyLoss(Loss):
  """Implementation of the usual cross_entropy loss with softmax."""
  def __init__(self, params, model, name="cross_entropy_loss"):
    super(CrossEntropyLoss, self).__init__(params, model, name)

  def _compute_loss(self, input_dict):
    # Logits come from the decoder output; targets are expected as one-hot
    # labels in the first entry of target_tensors.
    logits = input_dict['decoder_output']['logits']
    labels = input_dict['target_tensors'][0]
    return tf.losses.softmax_cross_entropy(logits=logits,
                                           onehot_labels=labels)
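
For illustration, a minimal sketch of the computation _compute_loss performs, using dummy logits and one-hot labels. The batch size, number of classes, label indices, and TF 1.x Session usage below are assumptions made for this example; in OpenSeq2Seq the input_dict is assembled by the model rather than built by hand.

import tensorflow as tf

batch_size, num_classes = 4, 10  # assumed shapes for this sketch
logits = tf.random_normal([batch_size, num_classes])
labels = tf.one_hot(indices=[1, 3, 5, 7], depth=num_classes)

# Same call as in CrossEntropyLoss._compute_loss: softmax cross-entropy
# with one-hot targets, averaged over the batch.
loss = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)

with tf.Session() as sess:
  print(sess.run(loss))  # prints a scalar loss value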