# Source code for decoders.fc_decoders

# Copyright (c) 2018 NVIDIA Corporation
"""This module defines various fully-connected decoders (consisting of one
fully connected layer).

These classes are usually used for models that are not really
sequence-to-sequence and thus should be artificially split into encoder and
decoder by cutting, for example, on the last fully-connected layer.
"""
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import os

import tensorflow as tf

from .decoder import Decoder


class FullyConnectedDecoder(Decoder):
    """Decoder that applies one fully-connected (linear) layer.

    Intended for models that are not truly sequence-to-sequence and are
    artificially split into encoder/decoder at the last dense layer.
    """

    @staticmethod
    def get_required_params():
        return dict(Decoder.get_required_params(), **{
            'output_dim': int,
        })

    def __init__(self, params, model, name="fully_connected_decoder",
                 mode='train'):
        """Fully connected decoder constructor.

        See parent class for arguments description.

        Config parameters:

        * **output_dim** (int) --- output dimension.
        """
        super(FullyConnectedDecoder, self).__init__(params, model, name, mode)

    def _decode(self, input_dict):
        """Perform a linear transformation of the encoder output.

        Args:
            input_dict (dict): must contain
                ``input_dict['encoder_output']['outputs']`` with
                shape ``[batch_size, num_features]``.

        Returns:
            dict: ``{'logits': logits, 'outputs': [logits]}`` where
            ``logits`` has shape ``[batch_size, output_dim]``.
        """
        encoder_out = input_dict['encoder_output']['outputs']
        reg = self.params.get('regularizer', None)
        # tf.layers.dense uses a linear activation by default, which is
        # exactly what a plain logits layer needs.
        logits = tf.layers.dense(
            inputs=encoder_out,
            units=self.params['output_dim'],
            kernel_regularizer=reg,
            name='fully_connected',
        )
        return {'logits': logits, 'outputs': [logits]}
class FullyConnectedTimeDecoder(Decoder):
    """Fully connected decoder for inputs that carry a time dimension.

    Input shape should be ``[batch size, time length, num features]``.
    """

    @staticmethod
    def get_required_params():
        return dict(Decoder.get_required_params(), **{
            'tgt_vocab_size': int,
        })

    @staticmethod
    def get_optional_params():
        return dict(Decoder.get_optional_params(), **{
            'logits_to_outputs_func': None,  # user defined function
            'infer_logits_to_pickle': bool,
        })

    def __init__(self, params, model, name="fully_connected_time_decoder",
                 mode='train'):
        """Fully connected time decoder constructor.

        See parent class for arguments description.

        Config parameters:

        * **tgt_vocab_size** (int) --- target vocabulary size, i.e. number of
          output features.
        * **logits_to_outputs_func** --- function that maps produced logits to
          decoder outputs, i.e. actual text sequences.
        """
        super(FullyConnectedTimeDecoder, self).__init__(params, model,
                                                        name, mode)

    def _decode(self, input_dict):
        """Create the graph for the fully connected time decoder.

        Args:
            input_dict (dict): must contain
                ``input_dict['encoder_output']['outputs']`` with shape
                ``[batch_size, time length, hidden dim]`` and
                ``input_dict['encoder_output']['src_length']`` with shape
                ``[batch_size]``.

        Returns:
            dict: ``'logits'`` with shape
            ``[time length, batch_size, tgt_vocab_size]`` (time-major; kept
            batch-major only when dumping logits to pickle in infer mode),
            ``'src_length'``, and — when ``logits_to_outputs_func`` is
            configured — ``'outputs'`` as produced by that function.
        """
        encoder_out = input_dict['encoder_output']['outputs']
        reg = self.params.get('regularizer', None)
        vocab_size = self.params['tgt_vocab_size']

        batch_size, _, n_hidden = encoder_out.get_shape().as_list()
        # Collapse [B, T, A] --> [B*T, A] so one dense layer is shared
        # across every timestep.
        flat = tf.reshape(encoder_out, [-1, n_hidden])

        # Activation is linear by default.
        flat_logits = tf.layers.dense(
            inputs=flat,
            units=vocab_size,
            kernel_regularizer=reg,
            name='fully_connected',
        )
        logits = tf.reshape(
            flat_logits,
            [batch_size, -1, vocab_size],
            name="logits",
        )

        # Convert to time_major=True shape, unless we are in infer mode and
        # raw logits are being pickled instead of decoded.
        if not (self._mode == 'infer' and
                self.params.get('infer_logits_to_pickle')):
            logits = tf.transpose(logits, [1, 0, 2])

        if 'logits_to_outputs_func' in self.params:
            outputs = self.params['logits_to_outputs_func'](logits, input_dict)
            return {
                'outputs': outputs,
                'logits': logits,
                'src_length': input_dict['encoder_output']['src_length'],
            }
        return {
            'logits': logits,
            'src_length': input_dict['encoder_output']['src_length'],
        }
class FullyConnectedCTCDecoder(FullyConnectedTimeDecoder):
    """Fully connected time decoder with CTC-based text generation.

    Works either with or without a language model. When no language model is
    used, ``tf.nn.ctc_greedy_decoder`` generates the output text.
    """

    @staticmethod
    def get_required_params():
        return dict(FullyConnectedTimeDecoder.get_required_params(), **{
            'use_language_model': bool,
        })

    @staticmethod
    def get_optional_params():
        return dict(FullyConnectedTimeDecoder.get_optional_params(), **{
            'decoder_library_path': str,
            'beam_width': int,
            'alpha': float,
            'beta': float,
            'trie_weight': float,
            'lm_path': str,
            'trie_path': str,
            'alphabet_config_path': str,
        })

    def __init__(self, params, model, name="fully_connected_ctc_decoder",
                 mode='train'):
        """Fully connected CTC decoder constructor.

        See parent class for arguments description.

        Config parameters:

        * **use_language_model** (bool) --- whether to use language model for
          output text generation. If False, other config parameters are not
          used.
        * **decoder_library_path** (string) --- path to the ctc decoder with
          language model library.
        * **lm_path** (string) --- path to the language model file.
        * **trie_path** (string) --- path to the prefix trie file.
        * **alphabet_config_path** (string) --- path to the alphabet file.
        * **beam_width** (int) --- beam width for beam search.
        * **alpha** (float) --- weight that is assigned to language model
          probabilities.
        * **beta** (float) --- weight that is assigned to the word count.
        * **trie_weight** (float) --- weight for prefix tree vocabulary based
          character level rescoring.
        """
        super(FullyConnectedCTCDecoder, self).__init__(params, model,
                                                       name, mode)

        if self.params['use_language_model']:
            # LM decoding requires the custom compiled op library; fail early
            # with a clear message if it is missing.
            lib_path = self.params['decoder_library_path']
            if not os.path.exists(os.path.abspath(lib_path)):
                raise IOError('Can\'t find the decoder with language model '
                              'library. Make sure you have built it and '
                              'check that you provide the correct '
                              'path in the --decoder_library_path parameter.')
            custom_op_module = tf.load_op_library(lib_path)

            def decode_with_lm(logits, decoder_input,
                               beam_width=self.params['beam_width'],
                               top_paths=1, merge_repeated=False):
                sequence_length = decoder_input['encoder_output']['src_length']
                # The custom op expects float32 logits.
                if logits.dtype.base_dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32)
                decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
                    custom_op_module.ctc_beam_search_decoder_with_lm(
                        logits, sequence_length, beam_width=beam_width,
                        model_path=self.params['lm_path'],
                        trie_path=self.params['trie_path'],
                        alphabet_path=self.params['alphabet_config_path'],
                        alpha=self.params['alpha'],
                        beta=self.params['beta'],
                        trie_weight=self.params.get('trie_weight', 0.1),
                        top_paths=top_paths, merge_repeated=merge_repeated,
                    )
                )
                # Only the best path is returned, wrapped as a SparseTensor.
                return [tf.SparseTensor(decoded_ixs[0], decoded_vals[0],
                                        decoded_shapes[0])]

            self.params['logits_to_outputs_func'] = decode_with_lm
        else:
            def decode_without_lm(logits, decoder_input, merge_repeated=True):
                # ctc_greedy_decoder expects float32 logits.
                if logits.dtype.base_dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32)
                decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(
                    logits, decoder_input['encoder_output']['src_length'],
                    merge_repeated,
                )
                return decoded

            self.params['logits_to_outputs_func'] = decode_without_lm
class FullyConnectedSCDecoder(Decoder):
    """Fully connected decoder for speech commands models.

    Flattens the encoder output and applies a single linear layer.
    """

    @staticmethod
    def get_required_params():
        return dict(Decoder.get_required_params(), **{
            'output_dim': int,
        })

    def __init__(self, params, model, name="fully_connected_decoder",
                 mode='train'):
        """Fully connected decoder constructor.

        See parent class for arguments description.

        Config parameters:

        * **output_dim** (int) --- output dimension.
        """
        super(FullyConnectedSCDecoder, self).__init__(params, model, name, mode)

    def _decode(self, input_dict):
        """Flatten the encoder output and apply a linear transformation.

        Args:
            input_dict (dict): must contain
                ``input_dict['encoder_output']['outputs']``, the encoder
                output tensor (flattened to ``[batch_size, -1]`` here).

        Returns:
            dict: ``{'logits': logits, 'outputs': [logits]}`` where
            ``logits`` has shape ``[batch_size, output_dim]``.
        """
        inputs = input_dict['encoder_output']['outputs']
        regularizer = self.params.get('regularizer', None)
        # NOTE(review): removed an unused local that read
        # input_dict['encoder_output']['src_length'] — sequence lengths are
        # never used in this decoder.
        inputs = tf.layers.flatten(inputs=inputs)
        # Activation is linear by default.
        logits = tf.layers.dense(
            inputs=inputs,
            units=self.params['output_dim'],
            kernel_regularizer=regularizer,
            name='fully_connected',
        )
        return {'logits': logits, 'outputs': [logits]}