Source code for decoders.convs2s_decoder

from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import tensorflow as tf
import math
from .decoder import Decoder

from open_seq2seq.parts.transformer import beam_search

from open_seq2seq.parts.transformer import embedding_layer
from open_seq2seq.parts.transformer.utils import get_padding

from open_seq2seq.parts.convs2s import ffn_wn_layer, conv_wn_layer, attention_wn_layer
from open_seq2seq.parts.convs2s.utils import gated_linear_units

# Default value used if max_input_length is not given
MAX_INPUT_LENGTH = 128


class ConvS2SDecoder(Decoder):
  @staticmethod
  def get_required_params():
    """Static method with description of required parameters.

    Returns:
      dict:
        Dictionary containing all the parameters that **have to** be
        included into the ``params`` parameter of the
        class :meth:`__init__` method.
    """
    return dict(
        Decoder.get_required_params(), **{
            'batch_size': int,
            'tgt_emb_size': int,
            'tgt_vocab_size': int,
            'shared_embed': bool,
            'embedding_dropout_keep_prob': float,
            'conv_nchannels_kwidth': list,
            'hidden_dropout_keep_prob': float,
            'out_dropout_keep_prob': float,
            'beam_size': int,
            'alpha': float,
            'extra_decode_length': int,
            'EOS_ID': int,
        })

  @staticmethod
  def get_optional_params():
    """Static method with description of optional parameters.

    Returns:
      dict:
        Dictionary containing all the parameters that **can** be
        included into the ``params`` parameter of the
        class :meth:`__init__` method.
    """
    return dict(
        Decoder.get_optional_params(), **{
            'pad_embeddings_2_eight': bool,
            # set the default to False later.
            "pos_embed": bool,
            # if not provided, tgt_emb_size is used as the default value
            'out_emb_size': int,
            'max_input_length': int,
            'GO_SYMBOL': int,
            'PAD_SYMBOL': int,
            'END_SYMBOL': int,
            'conv_activation': None,
            'normalization_type': str,
            'scaling_factor': float,
            'init_var': None,
        })

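  # Illustrative sketch (not part of the original source): a minimal
  # ``params`` dict that satisfies get_required_params() above. Every
  # concrete value here is a made-up placeholder, not a recommended setting.
  #
  #   params = {
  #       'batch_size': 32,
  #       'tgt_emb_size': 512,
  #       'tgt_vocab_size': 32000,
  #       'shared_embed': True,
  #       'embedding_dropout_keep_prob': 0.9,
  #       'conv_nchannels_kwidth': [[512, 3]] * 4,  # 4 conv blocks
  #       'hidden_dropout_keep_prob': 0.9,
  #       'out_dropout_keep_prob': 0.9,
  #       'beam_size': 4,
  #       'alpha': 0.6,
  #       'extra_decode_length': 50,
  #       'EOS_ID': 1,
  #   }
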
  def _cast_types(self, input_dict):
    return input_dict

  def __init__(self, params, model, name="convs2s_decoder", mode='train'):
    super(ConvS2SDecoder, self).__init__(params, model, name, mode)
    self.embedding_softmax_layer = None
    self.position_embedding_layer = None
    self.layers = []
    self._tgt_vocab_size = self.params['tgt_vocab_size']
    self._tgt_emb_size = self.params['tgt_emb_size']
    self._mode = mode
    self._pad_sym = self.params.get('PAD_SYMBOL', 0)
    self._pad2eight = params.get('pad_embeddings_2_eight', False)

    self.scaling_factor = self.params.get("scaling_factor", math.sqrt(0.5))
    self.normalization_type = self.params.get("normalization_type",
                                              "weight_norm")
    self.conv_activation = self.params.get("conv_activation",
                                           gated_linear_units)
    self.max_input_length = self.params.get("max_input_length",
                                            MAX_INPUT_LENGTH)
    self.init_var = self.params.get('init_var', None)
    self.regularizer = self.params.get('regularizer', None)

  def _decode(self, input_dict):
    targets = input_dict['target_tensors'][0] \
        if 'target_tensors' in input_dict else None

    encoder_outputs = input_dict['encoder_output']['outputs']
    encoder_outputs_b = input_dict['encoder_output'].get(
        'outputs_b', encoder_outputs)
    inputs_attention_bias = input_dict['encoder_output'].get(
        'inputs_attention_bias_cs2s', None)

    with tf.name_scope("decode"):
      # prepare decoder layers
      if len(self.layers) == 0:
        knum_list = list(zip(*self.params.get("conv_nchannels_kwidth")))[0]
        kwidth_list = list(zip(*self.params.get("conv_nchannels_kwidth")))[1]

        # preparing embedding layers
        with tf.variable_scope("embedding"):
          if 'embedding_softmax_layer' in input_dict['encoder_output'] \
              and self.params['shared_embed']:
            self.embedding_softmax_layer = \
                input_dict['encoder_output']['embedding_softmax_layer']
          else:
            self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
                vocab_size=self._tgt_vocab_size,
                hidden_size=self._tgt_emb_size,
                pad_vocab_to_eight=self._pad2eight,
                init_var=0.1,
                embed_scale=False,
                pad_sym=self._pad_sym,
                mask_paddings=True)

        if self.params.get("pos_embed", True):
          with tf.variable_scope("pos_embedding"):
            if 'position_embedding_layer' in input_dict['encoder_output'] \
                and self.params['shared_embed']:
              self.position_embedding_layer = \
                  input_dict['encoder_output']['position_embedding_layer']
            else:
              self.position_embedding_layer = embedding_layer.EmbeddingSharedWeights(
                  vocab_size=self.max_input_length,
                  hidden_size=self._tgt_emb_size,
                  pad_vocab_to_eight=self._pad2eight,
                  init_var=0.1,
                  embed_scale=False,
                  pad_sym=self._pad_sym,
                  mask_paddings=True)
        else:
          self.position_embedding_layer = None

        # linear projection before cnn layers
        self.layers.append(
            ffn_wn_layer.FeedFowardNetworkNormalized(
                self._tgt_emb_size,
                knum_list[0],
                dropout=self.params["embedding_dropout_keep_prob"],
                var_scope_name="linear_mapping_before_cnn_layers",
                mode=self.mode,
                normalization_type=self.normalization_type,
                regularizer=self.regularizer,
                init_var=self.init_var))

        for i in range(len(knum_list)):
          in_dim = knum_list[i] if i == 0 else knum_list[i - 1]
          out_dim = knum_list[i]

          # linear projection is needed for residual connections if
          # input and output of a cnn layer do not match
          if in_dim != out_dim:
            linear_proj = ffn_wn_layer.FeedFowardNetworkNormalized(
                in_dim,
                out_dim,
                var_scope_name="linear_mapping_cnn_" + str(i + 1),
                dropout=1.0,
                mode=self.mode,
                normalization_type=self.normalization_type,
                regularizer=self.regularizer,
                init_var=self.init_var)
          else:
            linear_proj = None

          conv_layer = conv_wn_layer.Conv1DNetworkNormalized(
              in_dim,
              out_dim,
              kernel_width=kwidth_list[i],
              mode=self.mode,
              layer_id=i + 1,
              hidden_dropout=self.params["hidden_dropout_keep_prob"],
              conv_padding="VALID",
              decode_padding=True,
              activation=self.conv_activation,
              normalization_type=self.normalization_type,
              regularizer=self.regularizer,
              init_var=self.init_var)

          att_layer = attention_wn_layer.AttentionLayerNormalized(
              out_dim,
              embed_size=self._tgt_emb_size,
              layer_id=i + 1,
              add_res=True,
              mode=self.mode,
              normalization_type=self.normalization_type,
              scaling_factor=self.scaling_factor,
              regularizer=self.regularizer,
              init_var=self.init_var)

          self.layers.append([linear_proj, conv_layer, att_layer])

        # linear projection after cnn layers
        self.layers.append(
            ffn_wn_layer.FeedFowardNetworkNormalized(
                knum_list[-1],
                self.params.get("out_emb_size", self._tgt_emb_size),
                dropout=1.0,
                var_scope_name="linear_mapping_after_cnn_layers",
                mode=self.mode,
                normalization_type=self.normalization_type,
                regularizer=self.regularizer,
                init_var=self.init_var))

        if not self.params['shared_embed']:
          self.layers.append(
              ffn_wn_layer.FeedFowardNetworkNormalized(
                  self.params.get("out_emb_size", self._tgt_emb_size),
                  self._tgt_vocab_size,
                  dropout=self.params["out_dropout_keep_prob"],
                  var_scope_name="linear_mapping_to_vocabspace",
                  mode=self.mode,
                  normalization_type=self.normalization_type,
                  regularizer=self.regularizer,
                  init_var=self.init_var))
        else:
          # if embedding is shared, the shared embedding is used as
          # the final linear projection to vocab space
          self.layers.append(None)

      if targets is None:
        return self.predict(encoder_outputs, encoder_outputs_b,
                            inputs_attention_bias)
      else:
        logits = self.decode_pass(targets, encoder_outputs,
                                  encoder_outputs_b, inputs_attention_bias)
        return {
            "logits": logits,
            "outputs": [tf.argmax(logits, axis=-1)],
            "final_state": None,
            "final_sequence_lengths": None
        }

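  # Sketch (not in the original source) of how _decode() above unpacks
  # "conv_nchannels_kwidth", a list of [nchannels, kwidth] pairs, into
  # per-block channel counts and kernel widths. The values are arbitrary:
  #
  #   conv_nchannels_kwidth = [[512, 3], [512, 3], [768, 5]]
  #   knum_list = list(zip(*conv_nchannels_kwidth))[0]    # (512, 512, 768)
  #   kwidth_list = list(zip(*conv_nchannels_kwidth))[1]  # (3, 3, 5)
  #
  # The self.layers list built above then has the structure:
  #   [input_ffn,
  #    [linear_proj_or_None, conv_layer, att_layer],  # one per conv block
  #    ...,
  #    output_ffn,
  #    vocab_projection_or_None]
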
  def decode_pass(self, targets, encoder_outputs, encoder_outputs_b,
                  inputs_attention_bias):
    """Generate logits for each value in the target sequence.

    Args:
      targets: target values for the output sequence.
        int tensor with shape [batch_size, target_length]
      encoder_outputs: continuous representation of input sequence.
        float tensor with shape [batch_size, input_length, hidden_size]
      encoder_outputs_b: continuous representation of input sequence
        which includes the source embeddings.
        float tensor with shape [batch_size, input_length, hidden_size]
      inputs_attention_bias: float tensor with shape
        [batch_size, 1, input_length]

    Returns:
      float32 tensor with shape [batch_size, target_length, vocab_size]
    """
    # Prepare inputs to decoder layers by applying embedding
    # and adding positional encoding.
    decoder_inputs = self.embedding_softmax_layer(targets)

    if self.position_embedding_layer is not None:
      with tf.name_scope("add_pos_encoding"):
        pos_input = tf.range(
            0,
            tf.shape(decoder_inputs)[1],
            delta=1,
            dtype=tf.int32,
            name='range')
        pos_encoding = self.position_embedding_layer(pos_input)
        decoder_inputs = decoder_inputs + tf.cast(
            x=pos_encoding, dtype=decoder_inputs.dtype)

    if self.mode == "train":
      decoder_inputs = tf.nn.dropout(
          decoder_inputs, self.params["embedding_dropout_keep_prob"])

    # mask the paddings in the target
    inputs_padding = get_padding(
        targets, padding_value=self._pad_sym, dtype=decoder_inputs.dtype)
    decoder_inputs *= tf.expand_dims(1.0 - inputs_padding, 2)

    # do decode
    logits = self._call(
        decoder_inputs=decoder_inputs,
        encoder_outputs_a=encoder_outputs,
        encoder_outputs_b=encoder_outputs_b,
        input_attention_bias=inputs_attention_bias)

    return logits

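  # Worked example (assumed, not from the original source) of the padding
  # masking in decode_pass(). With self._pad_sym == 0, get_padding() marks
  # padding positions with 1.0:
  #
  #   targets        = [[5, 7, 0, 0]]      # [batch_size=1, target_length=4]
  #   inputs_padding = [[0., 0., 1., 1.]]
  #   1.0 - inputs_padding  ->  [[1., 1., 0., 0.]]
  #
  # Expanding the mask to [1, 4, 1] and multiplying zeroes out the embeddings
  # at padded time steps, so they contribute nothing to the convolutions.
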
  def _call(self, decoder_inputs, encoder_outputs_a, encoder_outputs_b,
            input_attention_bias):
    # runs the input through the decoder layers and returns the logits
    target_embed = decoder_inputs
    with tf.variable_scope("linear_layer_before_cnn_layers"):
      outputs = self.layers[0](decoder_inputs)

    for i in range(1, len(self.layers) - 2):
      linear_proj, conv_layer, att_layer = self.layers[i]

      with tf.variable_scope("layer_%d" % i):
        if linear_proj is not None:
          res_inputs = linear_proj(outputs)
        else:
          res_inputs = outputs

        with tf.variable_scope("conv_layer"):
          outputs = conv_layer(outputs)

        with tf.variable_scope("attention_layer"):
          outputs = att_layer(outputs, target_embed,
                              encoder_outputs_a, encoder_outputs_b,
                              input_attention_bias)
        outputs = (outputs + res_inputs) * self.scaling_factor

    with tf.variable_scope("linear_layer_after_cnn_layers"):
      outputs = self.layers[-2](outputs)

    if self.mode == "train":
      outputs = tf.nn.dropout(outputs, self.params["out_dropout_keep_prob"])

    with tf.variable_scope("pre_softmax_projection"):
      if self.layers[-1] is None:
        logits = self.embedding_softmax_layer.linear(outputs)
      else:
        logits = self.layers[-1](outputs)

    return tf.cast(logits, dtype=tf.float32)

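  # Why scaling_factor defaults to sqrt(0.5) (reasoning as in the ConvS2S
  # paper, Gehring et al. 2017): if a block's output and its residual input
  # are roughly independent with equal variance v, their sum has variance 2v.
  # Multiplying the sum by sqrt(0.5) restores the variance to v, keeping
  # activations on the same scale as residual blocks are stacked:
  #
  #   Var(sqrt(0.5) * (outputs + res_inputs)) = 0.5 * 2v = v
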
  def predict(self, encoder_outputs, encoder_outputs_b,
              inputs_attention_bias):
    """Return predicted sequence."""
    batch_size = tf.shape(encoder_outputs)[0]
    input_length = tf.shape(encoder_outputs)[1]
    max_decode_length = input_length + self.params["extra_decode_length"]

    symbols_to_logits_fn = self._get_symbols_to_logits_fn()

    # Create initial set of IDs that will be passed into symbols_to_logits_fn.
    initial_ids = tf.zeros(
        [batch_size], dtype=tf.int32) + self.params["GO_SYMBOL"]

    cache = {}
    # Add encoder outputs and attention bias to the cache.
    cache["encoder_outputs"] = encoder_outputs
    cache["encoder_outputs_b"] = encoder_outputs_b
    if inputs_attention_bias is not None:
      cache["inputs_attention_bias"] = inputs_attention_bias

    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=initial_ids,
        initial_cache=cache,
        vocab_size=self.params["tgt_vocab_size"],
        beam_size=self.params["beam_size"],
        alpha=self.params["alpha"],
        max_decode_length=max_decode_length,
        eos_id=self.params["EOS_ID"])

    # Get the top sequence for each batch element
    top_decoded_ids = decoded_ids[:, 0, :]
    top_scores = scores[:, 0]

    # this isn't particularly efficient
    logits = self.decode_pass(top_decoded_ids, encoder_outputs,
                              encoder_outputs_b, inputs_attention_bias)
    return {
        "logits": logits,
        "outputs": [top_decoded_ids],
        "final_state": None,
        "final_sequence_lengths": None
    }

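  # Shape note (assumed from the transformer beam_search utilities, not
  # stated in the original source): sequence_beam_search() returns
  #   decoded_ids: [batch_size, beam_size, decoded_length]
  #   scores:      [batch_size, beam_size]
  # so decoded_ids[:, 0, :] selects each batch element's highest-scoring
  # beam, and the final decode_pass() re-runs the decoder over that single
  # sequence just to recover its logits.
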
  def _get_symbols_to_logits_fn(self):
    """Returns a decoding function that calculates logits of the next tokens."""

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
        ids: Current decoded sequences.
          int tensor with shape [batch_size * beam_size, i - 1]
        i: Loop index
        cache: dictionary of values storing the encoder output,
          encoder-decoder attention bias, and previous decoder attention
          values.

      Returns:
        Tuple of (logits with shape [batch_size * beam_size, vocab_size],
        updated cache values)
      """
      # pass the decoded ids from the beginning up to the current step
      # into the decoder. not efficient
      decoder_outputs = self.decode_pass(ids,
                                         cache.get("encoder_outputs"),
                                         cache.get("encoder_outputs_b"),
                                         cache.get("inputs_attention_bias"))
      logits = decoder_outputs[:, i, :]
      return logits, cache

    return symbols_to_logits_fn
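
# Usage sketch (not part of the original source): the closure returned by
# _get_symbols_to_logits_fn() follows the contract that
# beam_search.sequence_beam_search() expects in predict() above, with the
# beam dimension flattened into the batch dimension:
#
#   fn = decoder._get_symbols_to_logits_fn()
#   # ids:   tokens decoded so far, one row per (batch element, beam)
#   # cache: {"encoder_outputs": ..., "encoder_outputs_b": ...,
#   #         "inputs_attention_bias": ...}
#   logits, cache = fn(ids, i, cache)
#   # logits: [batch_size * beam_size, vocab_size] for decoding step i
#
# Because the whole prefix is re-decoded at every step instead of caching
# intermediate activations, decoding cost grows quadratically with length,
# which is what the "not efficient" comments above refer to.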