Source code for decoders.transformer_decoder
# This code is heavily based on the code from MLPerf
# https://github.com/mlperf/reference/tree/master/translation/tensorflow/transformer
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals
import tensorflow as tf
from six.moves import range
from open_seq2seq.parts.transformer import utils, attention_layer, \
ffn_layer, beam_search
from open_seq2seq.parts.transformer.common import PrePostProcessingWrapper, \
LayerNormalization, Transformer_BatchNorm
from .decoder import Decoder
class TransformerDecoder(Decoder):
  @staticmethod
def get_required_params():
"""Static method with description of required parameters.
Returns:
dict:
Dictionary containing all the parameters that **have to** be
included into the ``params`` parameter of the
class :meth:`__init__` method.
"""
return dict(Decoder.get_required_params(), **{
'EOS_ID': int,
'layer_postprocess_dropout': float,
'num_hidden_layers': int,
'hidden_size': int,
'num_heads': int,
'attention_dropout': float,
'relu_dropout': float,
'filter_size': int,
'batch_size': int,
'tgt_vocab_size': int,
'beam_size': int,
'alpha': float,
'extra_decode_length': int,
})
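  # Illustrative only: a hypothetical ``params`` dict satisfying the required
  # keys above (the values are made up for the example, not recommended
  # settings):
  #   {'EOS_ID': 1, 'layer_postprocess_dropout': 0.1, 'num_hidden_layers': 6,
  #    'hidden_size': 512, 'num_heads': 8, 'attention_dropout': 0.1,
  #    'relu_dropout': 0.1, 'filter_size': 2048, 'batch_size': 32,
  #    'tgt_vocab_size': 32000, 'beam_size': 4, 'alpha': 0.6,
  #    'extra_decode_length': 50}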
  @staticmethod
def get_optional_params():
"""Static method with description of optional parameters.
Returns:
dict:
Dictionary containing all the parameters that **can** be
included into the ``params`` parameter of the
class :meth:`__init__` method.
"""
return dict(Decoder.get_optional_params(), **{
'regularizer': None, # any valid TensorFlow regularizer
'regularizer_params': dict,
'initializer': None, # any valid TensorFlow initializer
'initializer_params': dict,
'GO_SYMBOL': int,
'PAD_SYMBOL': int,
'END_SYMBOL': int,
'norm_params': dict,
})
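  # ``norm_params`` chooses the output normalization built in _decode below:
  # {"type": "batch_norm"} selects Transformer_BatchNorm, anything else
  # (default {"type": "layernorm_L2"}, set in __init__) selects
  # LayerNormalization.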
def _cast_types(self, input_dict):
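    # Pass inputs through unchanged instead of the default dtype casting done
    # by the base Decoder; casts are applied explicitly where needed (see the
    # tf.cast calls below).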
return input_dict
def __init__(self, params, model,
name="transformer_decoder", mode='train'):
super(TransformerDecoder, self).__init__(params, model, name, mode)
self.embedding_softmax_layer = None
self.output_normalization = None
self._mode = mode
self.layers = []
    # In the original Transformer paper, embeddings are shared between the
    # encoder and the decoder, and the final projection is transpose(E_weights).
    # We currently only support this behaviour.
self.params['shared_embed'] = True
    self.norm_params = self.params.get("norm_params", {"type": "layernorm_L2"})
self.regularizer = self.params.get("regularizer", None)
    if self.regularizer is not None:
      self.regularizer_params = params.get("regularizer_params", {'scale': 0.0})
      self.regularizer = self.regularizer(self.regularizer_params['scale']) \
          if self.regularizer_params['scale'] > 0.0 else None
#print("reg", self.regularizer)
def _decode(self, input_dict):
if 'target_tensors' in input_dict:
targets = input_dict['target_tensors'][0]
else:
targets = None
encoder_outputs = input_dict['encoder_output']['outputs']
inputs_attention_bias = (
input_dict['encoder_output']['inputs_attention_bias']
)
self.embedding_softmax_layer = (
input_dict['encoder_output']['embedding_softmax_layer']
)
with tf.name_scope("decode"):
training = (self.mode == "train")
# prepare decoder layers
if len(self.layers) == 0:
for _ in range(self.params["num_hidden_layers"]):
self_attention_layer = attention_layer.SelfAttention(
hidden_size=self.params["hidden_size"],
num_heads=self.params["num_heads"],
attention_dropout=self.params["attention_dropout"],
train=training,
regularizer=self.regularizer
)
enc_dec_attention_layer = attention_layer.Attention(
hidden_size=self.params["hidden_size"],
num_heads=self.params["num_heads"],
attention_dropout=self.params["attention_dropout"],
train=training,
regularizer=self.regularizer
)
feed_forward_network = ffn_layer.FeedFowardNetwork(
hidden_size=self.params["hidden_size"],
filter_size=self.params["filter_size"],
relu_dropout=self.params["relu_dropout"],
train=training,
regularizer=self.regularizer
)
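          # Each decoder layer is stored as [self-attention, encoder-decoder
          # attention, feed-forward]. Each sub-layer is wrapped in
          # PrePostProcessingWrapper, which (per parts/transformer/common.py)
          # normalizes the input before the sub-layer and applies dropout plus
          # a residual connection after it.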
self.layers.append([
PrePostProcessingWrapper(self_attention_layer, self.params,
training),
PrePostProcessingWrapper(enc_dec_attention_layer, self.params,
training),
PrePostProcessingWrapper(feed_forward_network, self.params,
training)
])
print("Decoder:", self.norm_params["type"], self.mode)
if self.norm_params["type"] == "batch_norm":
self.output_normalization = Transformer_BatchNorm(
training=training,
params=self.norm_params)
else:
self.output_normalization = LayerNormalization(
hidden_size=self.params["hidden_size"],
params=self.norm_params)
if targets is None:
return self.predict(encoder_outputs, inputs_attention_bias)
else:
logits = self.decode_pass(targets, encoder_outputs,
inputs_attention_bias)
return {"logits": logits,
"outputs": [tf.argmax(logits, axis=-1)],
"final_state": None,
"final_sequence_lengths": None}
def _call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
attention_bias, cache=None):
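    # Run the decoder stack: each layer applies masked self-attention,
    # attention over the encoder outputs, and a feed-forward network.
    # ``cache`` holds per-layer key/value tensors during inference.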
for n, layer in enumerate(self.layers):
self_attention_layer = layer[0]
enc_dec_attention_layer = layer[1]
feed_forward_network = layer[2]
# Run inputs through the sublayers.
layer_name = "layer_%d" % n
layer_cache = cache[layer_name] if cache is not None else None
with tf.variable_scope(layer_name):
with tf.variable_scope("self_attention"):
# TODO: Figure out why this is needed
# decoder_self_attention_bias = tf.cast(x=decoder_self_attention_bias,
# dtype=decoder_inputs.dtype)
decoder_inputs = self_attention_layer(
decoder_inputs, decoder_self_attention_bias, cache=layer_cache,
)
with tf.variable_scope("encdec_attention"):
decoder_inputs = enc_dec_attention_layer(
decoder_inputs, encoder_outputs, attention_bias,
)
with tf.variable_scope("ffn"):
decoder_inputs = feed_forward_network(decoder_inputs)
return self.output_normalization(decoder_inputs)
  def decode_pass(self, targets, encoder_outputs, inputs_attention_bias):
"""Generate logits for each value in the target sequence.
Args:
targets: target values for the output sequence.
int tensor with shape [batch_size, target_length]
encoder_outputs: continuous representation of input sequence.
float tensor with shape [batch_size, input_length, hidden_size]
inputs_attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
Returns:
float32 tensor with shape [batch_size, target_length, vocab_size]
"""
# Prepare inputs to decoder layers by shifting targets, adding positional
# encoding and applying dropout.
decoder_inputs = self.embedding_softmax_layer(targets)
with tf.name_scope("shift_targets"):
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(
decoder_inputs, [[0, 0], [1, 0], [0, 0]],
)[:, :-1, :]
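      # The zero vector padded in at position 0 acts as the start-of-sequence
      # input, so position t is predicted only from targets before t.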
with tf.name_scope("add_pos_encoding"):
length = tf.shape(decoder_inputs)[1]
# decoder_inputs += utils.get_position_encoding(
# length, self.params["hidden_size"])
decoder_inputs += tf.cast(
utils.get_position_encoding(length, self.params["hidden_size"]),
dtype=self.params['dtype'],
)
if self.mode == "train":
      decoder_inputs = tf.nn.dropout(
          decoder_inputs,
          keep_prob=1.0 - self.params["layer_postprocess_dropout"])
# Run values
    decoder_self_attention_bias = utils.get_decoder_self_attention_bias(
        length, dtype=tf.float32,
        # dtype=self._params["dtype"]
    )
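    # decoder_self_attention_bias is a [1, 1, length, length] causal mask:
    # large negative values above the diagonal block attention to future
    # positions.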
# do decode
outputs = self._call(
decoder_inputs=decoder_inputs,
encoder_outputs=encoder_outputs,
decoder_self_attention_bias=decoder_self_attention_bias,
attention_bias=inputs_attention_bias,
)
logits = self.embedding_softmax_layer.linear(outputs)
return logits
  def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal = utils.get_position_encoding(
max_decode_length + 1, self.params["hidden_size"],
)
    decoder_self_attention_bias = utils.get_decoder_self_attention_bias(
        max_decode_length, dtype=tf.float32,
        # dtype=self._params["dtype"]
    )
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences.
int tensor with shape [batch_size * beam_size, i + 1]
i: Loop index
cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
Returns:
Tuple of
(logits with shape [batch_size * beam_size, vocab_size],
updated cache values)
"""
# Set decoder input to the last generated IDs
decoder_input = ids[:, -1:]
# Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input = self.embedding_softmax_layer(decoder_input)
decoder_input += tf.cast(x=timing_signal[i:i + 1],
dtype=decoder_input.dtype)
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
decoder_outputs = self._call(
decoder_input, cache.get("encoder_outputs"), self_attention_bias,
cache.get("encoder_decoder_attention_bias"), cache,
)
logits = self.embedding_softmax_layer.linear(decoder_outputs)
logits = tf.squeeze(logits, axis=[1])
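      # Logits are cast to float32 so beam search scoring is done in full
      # precision regardless of the model dtype.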
return tf.cast(logits, tf.float32), cache
return symbols_to_logits_fn
  def predict(self, encoder_outputs, encoder_decoder_attention_bias):
"""Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params["extra_decode_length"]
symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
# Create initial set of IDs that will be passed into symbols_to_logits_fn.
initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Create cache storing decoder attention values for each layer.
cache = {
"layer_%d" % layer: {
"k": tf.zeros([batch_size, 0,
self.params["hidden_size"]],
dtype=encoder_outputs.dtype),
"v": tf.zeros([batch_size, 0,
self.params["hidden_size"]],
dtype=encoder_outputs.dtype),
} for layer in range(self.params["num_hidden_layers"])
}
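    # The per-layer "k"/"v" tensors start with a time dimension of 0 and are
    # extended by one position at every decoding step inside the self-attention
    # layers.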
# Add encoder output and attention bias to the cache.
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=initial_ids,
initial_cache=cache,
vocab_size=self.params["tgt_vocab_size"],
beam_size=self.params["beam_size"],
alpha=self.params["alpha"],
max_decode_length=max_decode_length,
eos_id=self.params["EOS_ID"],
)
# Get the top sequence for each batch element
top_decoded_ids = decoded_ids[:, 0, 1:]
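    # decoded_ids has shape [batch_size, beam_size, decoded_length]; beam 0 is
    # the highest-scoring hypothesis and the leading entry is the initial id,
    # which is stripped here.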
    # This isn't particularly efficient: a full decode pass is re-run over the
    # selected sequence just to recover its logits.
logits = self.decode_pass(top_decoded_ids, encoder_outputs,
encoder_decoder_attention_bias)
return {"logits": logits,
"outputs": [top_decoded_ids],
"final_state": None,
"final_sequence_lengths": None}