Source code for decoders.rnn_decoders

# Copyright (c) 2018 NVIDIA Corporation
"""
RNN-based decoders.
"""
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import copy

import tensorflow as tf

from open_seq2seq.parts.rnns.attention_wrapper import BahdanauAttention, \
                                                      LuongAttention, \
                                                      AttentionWrapper
from open_seq2seq.parts.rnns.gnmt import GNMTAttentionMultiCell, \
                                         gnmt_residual_fn
from open_seq2seq.parts.rnns.rnn_beam_search_decoder import BeamSearchDecoder
from open_seq2seq.parts.rnns.utils import single_cell
from .decoder import Decoder


class RNNDecoderWithAttention(Decoder):
  """Typical RNN decoder with attention mechanism."""

  @staticmethod
  def get_required_params():
    return dict(Decoder.get_required_params(), **{
        'GO_SYMBOL': int,  # symbol id
        'END_SYMBOL': int,  # symbol id
        'tgt_vocab_size': int,
        'tgt_emb_size': int,
        'attention_layer_size': int,
        'attention_type': ['bahdanau', 'luong', 'gnmt', 'gnmt_v2'],
        'core_cell': None,
        'decoder_layers': int,
        'decoder_use_skip_connections': bool,
        'batch_size': int,
    })

  @staticmethod
  def get_optional_params():
    return dict(Decoder.get_optional_params(), **{
        'core_cell_params': dict,
        'bahdanau_normalize': bool,
        'luong_scale': bool,
        'decoder_dp_input_keep_prob': float,
        'decoder_dp_output_keep_prob': float,
        'time_major': bool,
        'use_swap_memory': bool,
        'proj_size': int,
        'num_groups': int,
        'PAD_SYMBOL': int,  # symbol id
        'weight_tied': bool,
    })

  def __init__(self, params, model,
               name='rnn_decoder_with_attention', mode='train'):
    """Initializes RNN decoder with embedding.

    See parent class for arguments description.

    Config parameters:

    * **batch_size** (int) --- batch size.
    * **GO_SYMBOL** (int) --- GO symbol id, must be the same as used in
      data layer.
    * **END_SYMBOL** (int) --- END symbol id, must be the same as used in
      data layer.
    * **tgt_vocab_size** (int) --- target vocabulary size.
    * **tgt_emb_size** (int) --- embedding size to use.
    * **core_cell** (string) --- RNN cell class to use.
    * **core_cell_params** (dict) --- parameters for the RNN cell class.
    * **decoder_dp_input_keep_prob** (float) --- dropout input keep
      probability.
    * **decoder_dp_output_keep_prob** (float) --- dropout output keep
      probability.
    * **decoder_use_skip_connections** (bool) --- whether to use residual
      connections.
    * **attention_type** (string) --- one of 'bahdanau', 'luong', 'gnmt'
      or 'gnmt_v2'.
    * **bahdanau_normalize** (bool, optional) --- whether to use
      normalization in bahdanau attention.
    * **luong_scale** (bool, optional) --- whether to use scale in luong
      attention.
    * ... add any cell-specific parameters here as well.
    """
    super(RNNDecoderWithAttention, self).__init__(params, model, name, mode)
    self._batch_size = self.params['batch_size']
    self.GO_SYMBOL = self.params['GO_SYMBOL']
    self.END_SYMBOL = self.params['END_SYMBOL']
    self._tgt_vocab_size = self.params['tgt_vocab_size']
    self._tgt_emb_size = self.params['tgt_emb_size']
    self._weight_tied = self.params.get('weight_tied', False)
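
  # Illustrative example (not part of the original source): a possible config
  # for this decoder, assuming an LSTM core cell. All values, and the use of
  # tf.nn.rnn_cell.LSTMCell as 'core_cell', are assumptions; see
  # get_required_params() / get_optional_params() above for the full schema.
  #
  #   decoder_params = {
  #       'batch_size': 32,
  #       'GO_SYMBOL': 1,
  #       'END_SYMBOL': 2,
  #       'tgt_vocab_size': 32000,
  #       'tgt_emb_size': 512,
  #       'attention_layer_size': 512,
  #       'attention_type': 'gnmt_v2',
  #       'core_cell': tf.nn.rnn_cell.LSTMCell,
  #       'core_cell_params': {'num_units': 512},
  #       'decoder_layers': 4,
  #       'decoder_use_skip_connections': True,
  #       'decoder_dp_input_keep_prob': 0.8,
  #       'decoder_dp_output_keep_prob': 1.0,
  #   }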

  def _build_attention(self, encoder_outputs, encoder_sequence_length):
    """Builds attention part of the graph.
    Currently supports "bahdanau", "luong", "gnmt" and "gnmt_v2".
    """
    with tf.variable_scope("AttentionMechanism"):
      attention_depth = self.params['attention_layer_size']
      if self.params['attention_type'] == 'bahdanau':
        bah_normalize = self.params.get('bahdanau_normalize', False)
        attention_mechanism = BahdanauAttention(
            num_units=attention_depth,
            memory=encoder_outputs,
            normalize=bah_normalize,
            memory_sequence_length=encoder_sequence_length,
            probability_fn=tf.nn.softmax,
            dtype=tf.get_variable_scope().dtype,
        )
      elif self.params['attention_type'] == 'luong':
        luong_scale = self.params.get('luong_scale', False)
        attention_mechanism = LuongAttention(
            num_units=attention_depth,
            memory=encoder_outputs,
            scale=luong_scale,
            memory_sequence_length=encoder_sequence_length,
            probability_fn=tf.nn.softmax,
            dtype=tf.get_variable_scope().dtype,
        )
      elif self.params['attention_type'] in ('gnmt', 'gnmt_v2'):
        # GNMT-style attention always uses normalized Bahdanau attention
        attention_mechanism = BahdanauAttention(
            num_units=attention_depth,
            memory=encoder_outputs,
            normalize=True,
            memory_sequence_length=encoder_sequence_length,
            probability_fn=tf.nn.softmax,
            dtype=tf.get_variable_scope().dtype,
        )
      else:
        raise ValueError(
            'Unknown attention type: {}'.format(self.params['attention_type'])
        )
      return attention_mechanism
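
  # Summary (not part of the original source) of how attention_type maps to
  # the mechanism constructed in _build_attention() above:
  #
  #   'bahdanau' -> BahdanauAttention(normalize=params.get('bahdanau_normalize', False))
  #   'luong'    -> LuongAttention(scale=params.get('luong_scale', False))
  #   'gnmt'     -> BahdanauAttention(normalize=True)
  #   'gnmt_v2'  -> BahdanauAttention(normalize=True)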

  @staticmethod
  def _add_residual_wrapper(cells, start_ind=1):
    """Wraps all cells starting from index ``start_ind`` in residual
    wrappers."""
    for idx, cell in enumerate(cells):
      if idx >= start_ind:
        # pylint: disable=no-member
        cells[idx] = tf.contrib.rnn.ResidualWrapper(
            cell,
            residual_fn=gnmt_residual_fn,
        )
    return cells

  def _decode(self, input_dict):
    """Decodes representation into data.

    Args:
      input_dict (dict): Python dictionary with inputs to decoder.

    Must define:

    * **src_inputs** --- decoder input Tensor of shape
      [batch_size, time, dim] or [time, batch_size, dim].
    * **src_lengths** --- decoder input lengths Tensor of shape [batch_size].
    * **tgt_inputs** --- only during training. Labels Tensor of shape
      [batch_size, time] or [time, batch_size].
    * **tgt_lengths** --- only during training. Labels lengths Tensor of
      shape [batch_size].

    Returns:
      dict: Python dictionary with:

      * logits --- tensor of shape [batch_size, time, tgt_vocab_size]
      * outputs --- list with predicted token ids
      * final_state --- decoder final state
      * final_sequence_lengths --- tensor of shape [batch_size]
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']
    tgt_inputs = input_dict['target_tensors'][0] \
        if 'target_tensors' in input_dict else None
    tgt_lengths = input_dict['target_tensors'][1] \
        if 'target_tensors' in input_dict else None

    self._output_projection_layer = tf.layers.Dense(
        self._tgt_vocab_size,
        use_bias=False,
    )

    if not self._weight_tied:
      self._dec_emb_w = tf.get_variable(
          name='DecoderEmbeddingMatrix',
          shape=[self._tgt_vocab_size, self._tgt_emb_size],
          dtype=tf.float32,
      )
    else:
      # build the projection layer's variables so that its kernel can be
      # reused (transposed) as the embedding matrix
      fake_input = tf.zeros(shape=(1, self._tgt_emb_size))
      fake_output = self._output_projection_layer.apply(fake_input)
      with tf.variable_scope("dense", reuse=True):
        dense_weights = tf.get_variable("kernel")
        self._dec_emb_w = tf.transpose(dense_weights)

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    residual_connections = self.params['decoder_use_skip_connections']

    # list of cells
    cell_params = self.params.get('core_cell_params', {})
    self._decoder_cells = [
        single_cell(
            cell_class=self.params['core_cell'],
            cell_params=cell_params,
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob,
            # residual connections are added a little differently for GNMT
            residual_connections=False
            if self.params['attention_type'].startswith('gnmt')
            else residual_connections,
        ) for _ in range(self.params['decoder_layers'] - 1)
    ]
    # when weights are tied, the last cell must output vectors of the
    # embedding size
    last_cell_params = copy.deepcopy(cell_params)
    if self._weight_tied:
      last_cell_params['num_units'] = self._tgt_emb_size
    last_cell = single_cell(
        cell_class=self.params['core_cell'],
        cell_params=last_cell_params,
        dp_input_keep_prob=dp_input_keep_prob,
        dp_output_keep_prob=dp_output_keep_prob,
        # residual connections are added a little differently for GNMT
        residual_connections=False
        if self.params['attention_type'].startswith('gnmt')
        else residual_connections,
    )
    self._decoder_cells.append(last_cell)

    attention_mechanism = self._build_attention(
        encoder_outputs,
        enc_src_lengths,
    )
    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      attention_cell = AttentionWrapper(
          attention_cell,
          attention_mechanism=attention_mechanism,
          attention_layer_size=None,
          output_attention=False,
          name="gnmt_attention",
      )
      attentive_decoder_cell = GNMTAttentionMultiCell(
          attention_cell,
          self._add_residual_wrapper(self._decoder_cells)
          if residual_connections else self._decoder_cells,
          use_new_attention=(self.params['attention_type'] == 'gnmt_v2'),
      )
    else:
      # pylint: disable=no-member
      attentive_decoder_cell = AttentionWrapper(
          cell=tf.contrib.rnn.MultiRNNCell(self._decoder_cells),
          attention_mechanism=attention_mechanism,
      )

    if self._mode == "train":
      input_vectors = tf.cast(
          tf.nn.embedding_lookup(self._dec_emb_w, tgt_inputs),
          dtype=self.params['dtype'],
      )
      helper = tf.contrib.seq2seq.TrainingHelper(  # pylint: disable=no-member
          inputs=input_vectors,
          sequence_length=tgt_lengths,
      )
      decoder = tf.contrib.seq2seq.BasicDecoder(  # pylint: disable=no-member
          cell=attentive_decoder_cell,
          helper=helper,
          output_layer=self._output_projection_layer,
          initial_state=attentive_decoder_cell.zero_state(
              self._batch_size, dtype=encoder_outputs.dtype,
          ),
      )
    elif self._mode == "infer" or self._mode == "eval":
      embedding_fn = lambda ids: tf.cast(
          tf.nn.embedding_lookup(self._dec_emb_w, ids),
          dtype=self.params['dtype'],
      )
      # pylint: disable=no-member
      helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
          embedding=embedding_fn,
          start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
          end_token=self.END_SYMBOL,
      )
      decoder = tf.contrib.seq2seq.BasicDecoder(  # pylint: disable=no-member
          cell=attentive_decoder_cell,
          helper=helper,
          initial_state=attentive_decoder_cell.zero_state(
              batch_size=self._batch_size,
              dtype=encoder_outputs.dtype,
          ),
          output_layer=self._output_projection_layer,
      )
    else:
      raise ValueError(
          "Unknown mode for decoder: {}".format(self._mode)
      )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    if self._mode == 'train':
      maximum_iterations = tf.reduce_max(tgt_lengths)
    else:
      maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

    # pylint: disable=no-member
    final_outputs, final_state, final_sequence_lengths = \
        tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            impute_finished=True,
            maximum_iterations=maximum_iterations,
            swap_memory=use_swap_memory,
            output_time_major=time_major,
        )

    return {
        'logits': final_outputs.rnn_output if not time_major else
                  tf.transpose(final_outputs.rnn_output, perm=[1, 0, 2]),
        'outputs': [tf.argmax(final_outputs.rnn_output, axis=-1)],
        'final_state': final_state,
        'final_sequence_lengths': final_sequence_lengths,
    }
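
# Illustrative sketch (not part of the original source): the structure of the
# input_dict consumed by _decode() above. The tensors and their shapes are
# assumptions; decode() is the public entry point inherited from the Decoder
# base class.
#
#   input_dict = {
#       'encoder_output': {
#           'outputs': encoder_outputs,      # [batch_size, src_time, dim]
#           'src_lengths': enc_src_lengths,  # [batch_size]
#       },
#       # only during training:
#       'target_tensors': [tgt_inputs,       # [batch_size, tgt_time]
#                          tgt_lengths],     # [batch_size]
#   }
#   results = decoder.decode(input_dict=input_dict)
#   logits = results['logits']               # [batch_size, tgt_time, vocab]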


class BeamSearchRNNDecoderWithAttention(RNNDecoderWithAttention):
  """Beam search version of the RNN-based decoder with attention.
  Can be used only during inference (mode='infer').
  """

  @staticmethod
  def get_optional_params():
    return dict(RNNDecoderWithAttention.get_optional_params(), **{
        'length_penalty': float,
        'beam_width': int,
    })

  def __init__(self, params, model,
               name="rnn_decoder_with_attention", mode='train'):
    """Initializes beam search decoder.

    Args:
      params (dict): dictionary with decoder parameters.

    Config parameters:

    * **batch_size** --- batch size.
    * **GO_SYMBOL** --- GO symbol id, must be the same as used in data layer.
    * **END_SYMBOL** --- END symbol id, must be the same as used in data
      layer.
    * **tgt_vocab_size** --- vocabulary size of the target.
    * **tgt_emb_size** --- embedding size to use.
    * **core_cell** --- RNN cell class to use.
    * **core_cell_params** --- parameters for the RNN cell class.
    * **decoder_dp_input_keep_prob** --- dropout input keep probability.
    * **decoder_dp_output_keep_prob** --- dropout output keep probability.
    * **decoder_use_skip_connections** --- whether to use residual
      connections.
    * **attention_type** --- bahdanau, luong, gnmt or gnmt_v2.
    * **bahdanau_normalize** --- (optional) whether to use normalization in
      bahdanau attention.
    * **luong_scale** --- (optional) whether to use scale in luong attention.
    * **length_penalty** --- (optional) beam search length penalty weight,
      defaults to 0.0.
    * **beam_width** --- (optional) beam width, defaults to 1.
    * **mode** --- decoder mode; only 'infer' is supported.
    * ... add any cell-specific parameters here as well.
    """
    super(BeamSearchRNNDecoderWithAttention, self).__init__(
        params, model, name, mode,
    )
    if self._mode != 'infer':
      raise ValueError(
          'BeamSearch decoder only supports infer mode, but got {}'.format(
              self._mode,
          )
      )
    self._length_penalty_weight = self.params.get("length_penalty", 0.0)
    # a beam_width of 1 should give the same results as the argmax decoder
    self._beam_width = self.params.get("beam_width", 1)
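
  # Illustrative example (not part of the original source): the two extra
  # config entries this class adds on top of RNNDecoderWithAttention; the
  # values shown are assumptions.
  #
  #   decoder_params.update({
  #       'beam_width': 10,       # keep 10 hypotheses per step
  #       'length_penalty': 1.0,  # > 0 favors longer output sequences
  #   })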

  def _decode(self, input_dict):
    """Decodes representation into data.

    Args:
      input_dict (dict): Python dictionary with inputs to decoder.

    Must define:

    * **src_inputs** --- decoder input Tensor of shape
      [batch_size, time, dim] or [time, batch_size, dim].
    * **src_lengths** --- decoder input lengths Tensor of shape [batch_size].

    Does not need tgt_inputs and tgt_lengths.

    Returns:
      dict: a Python dictionary with:

      * logits --- predicted token ids of the best beam, shape
        [batch_size, time] or [time, batch_size]
      * outputs --- list with predicted token ids of the best beam
      * final_state --- decoder final state
      * final_sequence_lengths --- tensor with decoded sequence lengths
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']

    self._output_projection_layer = tf.layers.Dense(
        self._tgt_vocab_size,
        use_bias=False,
    )

    if not self._weight_tied:
      self._dec_emb_w = tf.get_variable(
          name='DecoderEmbeddingMatrix',
          shape=[self._tgt_vocab_size, self._tgt_emb_size],
          dtype=tf.float32,
      )
    else:
      # build the projection layer's variables so that its kernel can be
      # reused (transposed) as the embedding matrix
      fake_input = tf.zeros(shape=(1, self._tgt_emb_size))
      fake_output = self._output_projection_layer.apply(fake_input)
      with tf.variable_scope("dense", reuse=True):
        dense_weights = tf.get_variable("kernel")
        self._dec_emb_w = tf.transpose(dense_weights)

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    residual_connections = self.params['decoder_use_skip_connections']

    # list of cells
    cell_params = self.params.get('core_cell_params', {})
    self._decoder_cells = [
        single_cell(
            cell_class=self.params['core_cell'],
            cell_params=cell_params,
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob,
            # residual connections are added a little differently for GNMT
            residual_connections=False
            if self.params['attention_type'].startswith('gnmt')
            else residual_connections,
        ) for _ in range(self.params['decoder_layers'] - 1)
    ]
    # when weights are tied, the last cell must output vectors of the
    # embedding size
    last_cell_params = copy.deepcopy(cell_params)
    if self._weight_tied:
      last_cell_params['num_units'] = self._tgt_emb_size
    last_cell = single_cell(
        cell_class=self.params['core_cell'],
        cell_params=last_cell_params,
        dp_input_keep_prob=dp_input_keep_prob,
        dp_output_keep_prob=dp_output_keep_prob,
        # residual connections are added a little differently for GNMT
        residual_connections=False
        if self.params['attention_type'].startswith('gnmt')
        else residual_connections,
    )
    self._decoder_cells.append(last_cell)

    # beam search keeps beam_width hypotheses per example, so encoder outputs
    # and lengths are tiled along the batch dimension
    # pylint: disable=no-member
    tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs,
        multiplier=self._beam_width,
    )
    # pylint: disable=no-member
    tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
        enc_src_lengths,
        multiplier=self._beam_width,
    )
    attention_mechanism = self._build_attention(
        tiled_enc_outputs,
        tiled_enc_src_lengths,
    )

    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      attention_cell = AttentionWrapper(
          attention_cell,
          attention_mechanism=attention_mechanism,
          attention_layer_size=None,  # don't use attention layer
          output_attention=False,
          name="gnmt_attention",
      )
      attentive_decoder_cell = GNMTAttentionMultiCell(
          attention_cell,
          self._add_residual_wrapper(self._decoder_cells)
          if residual_connections else self._decoder_cells,
          use_new_attention=(self.params['attention_type'] == 'gnmt_v2'),
      )
    else:  # non-GNMT
      # pylint: disable=no-member
      attentive_decoder_cell = AttentionWrapper(
          cell=tf.contrib.rnn.MultiRNNCell(self._decoder_cells),
          attention_mechanism=attention_mechanism,
      )
    batch_size_tensor = tf.constant(self._batch_size)
    embedding_fn = lambda ids: tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, ids),
        dtype=self.params['dtype'],
    )
    decoder = BeamSearchDecoder(
        cell=attentive_decoder_cell,
        embedding=embedding_fn,
        start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
        end_token=self.END_SYMBOL,
        initial_state=attentive_decoder_cell.zero_state(
            dtype=encoder_outputs.dtype,
            batch_size=batch_size_tensor * self._beam_width,
        ),
        beam_width=self._beam_width,
        output_layer=self._output_projection_layer,
        length_penalty_weight=self._length_penalty_weight,
    )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    final_outputs, final_state, final_sequence_lengths = \
        tf.contrib.seq2seq.dynamic_decode(  # pylint: disable=no-member
            decoder=decoder,
            maximum_iterations=tf.reduce_max(enc_src_lengths) * 2,
            swap_memory=use_swap_memory,
            output_time_major=time_major,
        )

    return {
        # beam search does not expose logits, so the predicted token ids of
        # the best beam (predicted_ids[:, :, 0]) are returned under 'logits';
        # the slice is 2-D, hence the two-axis transpose for time_major
        'logits': final_outputs.predicted_ids[:, :, 0] if not time_major else
                  tf.transpose(final_outputs.predicted_ids[:, :, 0],
                               perm=[1, 0]),
        'outputs': [final_outputs.predicted_ids[:, :, 0]],
        'final_state': final_state,
        'final_sequence_lengths': final_sequence_lengths,
    }
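

# Illustrative usage sketch (not part of the original source): constructing
# the beam search decoder from a config. `decoder_params`, `model` and
# `input_dict` are hypothetical objects from the examples above; decode() is
# inherited from the Decoder base class.
#
#   beam_decoder = BeamSearchRNNDecoderWithAttention(
#       params=dict(decoder_params, beam_width=10, length_penalty=1.0),
#       model=model,
#       mode='infer',
#   )
#   results = beam_decoder.decode(input_dict=input_dict)
#   best_ids = results['outputs'][0]  # [batch_size, time], best beam only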