Source code for parts.rnns.rnn_beam_search_decoder

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from six.moves import range

import collections
import numpy as np
import tensorflow as tf

from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest

__all__ = [
    "BeamSearchDecoderOutput",
    "BeamSearchDecoderState",
    "BeamSearchDecoder",
    "FinalBeamSearchDecoderOutput",
    "tile_batch",
]


[docs]class BeamSearchDecoderState( collections.namedtuple("BeamSearchDecoderState", ("cell_state", "log_probs", "finished", "lengths"))): pass
[docs]class BeamSearchDecoderOutput( collections.namedtuple("BeamSearchDecoderOutput", ("scores", "predicted_ids", "parent_ids"))): pass
[docs]class FinalBeamSearchDecoderOutput( collections.namedtuple("FinalBeamDecoderOutput", ["predicted_ids", "beam_search_decoder_output"])): """Final outputs returned by the beam search after all decoding is finished. Args: predicted_ids: The final prediction. A tensor of shape `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if `output_time_major` is True). Beams are ordered from best to worst. beam_search_decoder_output: An instance of `BeamSearchDecoderOutput` that describes the state of the beam search. """ pass
def _tile_batch(t, multiplier): """Core single-tensor implementation of tile_batch.""" t = ops.convert_to_tensor(t, name="t") shape_t = array_ops.shape(t) if t.shape.ndims is None or t.shape.ndims < 1: raise ValueError("t must have statically known rank") tiling = [1] * (t.shape.ndims + 1) tiling[1] = multiplier tiled_static_batch_size = ( t.shape[0].value * multiplier if t.shape[0].value is not None else None) tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) tiled = array_ops.reshape(tiled, array_ops.concat( ([shape_t[0] * multiplier], shape_t[1:]), 0)) tiled.set_shape( tensor_shape.TensorShape([tiled_static_batch_size]).concatenate( t.shape[1:])) return tiled
[docs]def tile_batch(t, multiplier, name=None): """Tile the batch dimension of a (possibly nested structure of) tensor(s) t. For each tensor t in a (possibly nested structure) of tensors, this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated `multiplier` times. Args: t: `Tensor` shaped `[batch_size, ...]`. multiplier: Python int. name: Name scope for any created operations. Returns: A (possibly nested structure of) `Tensor` shaped `[batch_size * multiplier, ...]`. Raises: ValueError: if tensor(s) `t` do not have a statically known rank or the rank is < 1. """ flat_t = nest.flatten(t) with ops.name_scope(name, "tile_batch", flat_t + [multiplier]): return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
def _check_maybe(t): if isinstance(t, tensor_array_ops.TensorArray): raise TypeError( "TensorArray state is not supported by BeamSearchDecoder: %s" % t.name) if t.shape.ndims is None: raise ValueError( "Expected tensor (%s) to have known rank, but ndims == None." % t)
[docs]class BeamSearchDecoder(decoder.Decoder): """BeamSearch sampling decoder. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`, then you must ensure that: - The encoder output has been tiled to `beam_width` via @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`). - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to `true_batch_size * beam_width`. - The initial state created with `zero_state` above contains a `cell_state` value containing properly tiled final state from the encoder. An example: ``` tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( encoder_outputs, multiplier=beam_width) tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( encoder_final_state, multiplier=beam_width) tiled_sequence_length = tf.contrib.seq2seq.tile_batch( sequence_length, multiplier=beam_width) attention_mechanism = MyFavoriteAttentionMechanism( num_units=attention_depth, memory=tiled_inputs, memory_sequence_length=tiled_sequence_length) attention_cell = AttentionWrapper(cell, attention_mechanism, ...) decoder_initial_state = attention_cell.zero_state( dtype, batch_size=true_batch_size * beam_width) decoder_initial_state = decoder_initial_state.clone( cell_state=tiled_encoder_final_state) ``` """
[docs] def __init__(self, cell, embedding, start_tokens, end_token, initial_state, beam_width, output_layer=None, length_penalty_weight=0.0, positional_embedding=None): """Initialize the BeamSearchDecoder. Args: cell: An `RNNCell` instance. embedding: A callable that takes a vector tensor of `ids` (argmax ids), or the `params` argument for `embedding_lookup`. start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. end_token: `int32` scalar, the token that marks end of decoding. initial_state: A (possibly nested tuple of...) tensors and TensorArrays. beam_width: Python integer, the number of beams. output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. positional_embedding: A callable to use decoder positional embedding. Default is None in which case positional embedding is disabled Raises: TypeError: if `cell` is not an instance of `RNNCell`, or `output_layer` is not an instance of `tf.layers.Layer`. ValueError: If `start_tokens` is not a vector or `end_token` is not a scalar. """ rnn_cell_impl.assert_like_rnncell("cell", cell) if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer if callable(embedding): self._embedding_fn = embedding else: self._embedding_fn = ( lambda ids: embedding_ops.embedding_lookup(embedding, ids)) self._use_pos_embedding = False if positional_embedding is not None: if callable(positional_embedding): self._pos_embedding_fn = positional_embedding else: self._pos_embedding_fn = ( lambda ids: embedding_ops.embedding_lookup(positional_embedding, ids)) self._use_pos_embedding = True self._start_tokens = ops.convert_to_tensor( start_tokens, dtype=dtypes.int32, name="start_tokens") if self._start_tokens.get_shape().ndims != 1: raise ValueError("start_tokens must be a vector") self._end_token = ops.convert_to_tensor( end_token, dtype=dtypes.int32, name="end_token") if self._end_token.get_shape().ndims != 0: raise ValueError("end_token must be a scalar") self._batch_size = array_ops.size(start_tokens) self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._initial_cell_state = nest.map_structure( self._maybe_split_batch_beams, initial_state, self._cell.state_size) self._start_tokens = array_ops.tile( array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) self._start_inputs = self._embedding_fn(self._start_tokens) if self._use_pos_embedding: self._start_inputs += self._pos_embedding_fn(ops.convert_to_tensor(0)) self._finished = array_ops.one_hot( array_ops.zeros([self._batch_size], dtype=dtypes.int32), depth=self._beam_width, on_value=False, off_value=True, dtype=dtypes.bool)
@property def batch_size(self): return self._batch_size def _rnn_output_size(self): size = self._cell.output_size if self._output_layer is None: return size else: # To use layer's compute_output_shape, we need to convert the # RNNCell's output_size entries into shapes with an unknown # batch size. We then pass this through the layer's # compute_output_shape and read off all but the first (batch) # dimensions to get the output size of the rnn with the layer # applied to the top. output_shape_with_unknown_batch = nest.map_structure( lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) layer_output_shape = self._output_layer.compute_output_shape( output_shape_with_unknown_batch) return nest.map_structure(lambda s: s[1:], layer_output_shape) @property def tracks_own_finished(self): """The BeamSearchDecoder shuffles its beams and their finished state. For this reason, it conflicts with the `dynamic_decode` function's tracking of finished states. Setting this property to true avoids early stopping of decoding due to mismanagement of the finished state in `dynamic_decode`. Returns: `True`. """ return True @property def output_size(self): # Return the cell output and the id return BeamSearchDecoderOutput( scores=tensor_shape.TensorShape([self._beam_width]), predicted_ids=tensor_shape.TensorShape([self._beam_width]), parent_ids=tensor_shape.TensorShape([self._beam_width])) @property def output_dtype(self): # Assume the dtype of the cell is the output_size structure # containing the input_state's first component's dtype. # Return that structure and int32 (the id) dtype = nest.flatten(self._initial_cell_state)[0].dtype return BeamSearchDecoderOutput( scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), predicted_ids=dtypes.int32, parent_ids=dtypes.int32)
[docs] def initialize(self, name=None): """Initialize the decoder. Args: name: Name scope for any created operations. Returns: `(finished, start_inputs, initial_state)`. """ finished, start_inputs = self._finished, self._start_inputs dtype = nest.flatten(self._initial_cell_state)[0].dtype log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) array_ops.zeros([self._batch_size], dtype=dtypes.int32), depth=self._beam_width, on_value=math_ops.cast(0.0, dtype), off_value=-np.float16('inf') if dtype == dtypes.float16 else -np.Inf, dtype=dtype) initial_state = BeamSearchDecoderState( cell_state=self._initial_cell_state, log_probs=log_probs, finished=finished, lengths=array_ops.zeros( [self._batch_size, self._beam_width], dtype=dtypes.int64)) return (finished, start_inputs, initial_state)
[docs] def finalize(self, outputs, final_state, sequence_lengths): """Finalize and return the predicted_ids. Args: outputs: An instance of BeamSearchDecoderOutput. final_state: An instance of BeamSearchDecoderState. Passed through to the output. sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`. The sequence lengths determined for each beam during decode. **NOTE** These are ignored; the updated sequence lengths are stored in `final_state.lengths`. Returns: outputs: An instance of `FinalBeamSearchDecoderOutput` where the predicted_ids are the result of calling _gather_tree. final_state: The same input instance of `BeamSearchDecoderState`. """ del sequence_lengths # Get max_sequence_length across all beams for each batch. max_sequence_lengths = tf.cast( math_ops.reduce_max(final_state.lengths, axis=1),tf.int32) predicted_ids = beam_search_ops.gather_tree( outputs.predicted_ids, outputs.parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=self._end_token) outputs = FinalBeamSearchDecoderOutput( beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state
[docs] def _merge_batch_beams(self, t, s=None): """Merges the tensor from a batch of beams into a batch by beams. More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We reshape this into [batch_size*beam_width, s] Args: t: Tensor of dimension [batch_size, beam_width, s] s: (Possibly known) depth shape. Returns: A reshaped version of t with dimension [batch_size * beam_width, s]. """ if isinstance(s, ops.Tensor): s = tensor_shape.as_shape(tensor_util.constant_value(s)) else: s = tensor_shape.TensorShape(s) t_shape = array_ops.shape(t) static_batch_size = tensor_util.constant_value(self._batch_size) batch_size_beam_width = ( None if static_batch_size is None else static_batch_size * self._beam_width) reshaped_t = array_ops.reshape( t, array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]), 0)) reshaped_t.set_shape( (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s))) return reshaped_t
[docs] def _split_batch_beams(self, t, s=None): """Splits the tensor from a batch by beams into a batch of beams. More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We reshape this into [batch_size, beam_width, s] Args: t: Tensor of dimension [batch_size*beam_width, s]. s: (Possibly known) depth shape. Returns: A reshaped version of t with dimension [batch_size, beam_width, s]. Raises: ValueError: If, after reshaping, the new tensor is not shaped `[batch_size, beam_width, s]` (assuming batch_size and beam_width are known statically). """ if isinstance(s, ops.Tensor): s = tensor_shape.TensorShape(tensor_util.constant_value(s)) else: s = tensor_shape.TensorShape(s) t_shape = array_ops.shape(t) reshaped_t = array_ops.reshape( t, array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]), 0)) static_batch_size = tensor_util.constant_value(self._batch_size) expected_reshaped_shape = tensor_shape.TensorShape( [static_batch_size, self._beam_width]).concatenate(s) if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape): raise ValueError("Unexpected behavior when reshaping between beam width " "and batch size. The reshaped tensor has shape: %s. " "We expected it to have shape " "(batch_size, beam_width, depth) == %s. Perhaps you " "forgot to create a zero_state with " "batch_size=encoder_batch_size * beam_width?" % (reshaped_t.shape, expected_reshaped_shape)) reshaped_t.set_shape(expected_reshaped_shape) return reshaped_t
[docs] def _maybe_split_batch_beams(self, t, s): """Maybe splits the tensor from a batch by beams into a batch of beams. We do this so that we can use nest and not run into problems with shapes. Args: t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`. s: `Tensor`, Python int, or `TensorShape`. Returns: If `t` is a matrix or higher order tensor, then the return value is `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is returned unchanged. Raises: TypeError: If `t` is an instance of `TensorArray`. ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 1: return self._split_batch_beams(t, s) else: return t
[docs] def _maybe_merge_batch_beams(self, t, s): """Splits the tensor from a batch by beams into a batch of beams. More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`, then we reshape it to `[batch_size, beam_width] + s`. Args: t: `Tensor` of dimension `[batch_size * beam_width] + s`. s: `Tensor`, Python int, or `TensorShape`. Returns: A reshaped version of t with shape `[batch_size, beam_width] + s`. Raises: TypeError: If `t` is an instance of `TensorArray`. ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 2: return self._merge_batch_beams(t, s) else: return t
[docs] def step(self, time, inputs, state, name=None): """Perform a decoding step. Args: time: scalar `int32` tensor. inputs: A (structure of) input tensors. state: A (structure of) state tensors and TensorArrays. name: Name scope for any created operations. Returns: `(outputs, next_state, next_inputs, finished)`. """ batch_size = self._batch_size beam_width = self._beam_width end_token = self._end_token length_penalty_weight = self._length_penalty_weight with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): cell_state = state.cell_state inputs = nest.map_structure( lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs) cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size) cell_outputs, next_cell_state = self._cell(inputs, cell_state) cell_outputs = nest.map_structure( lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs) next_cell_state = nest.map_structure( self._maybe_split_batch_beams, next_cell_state, self._cell.state_size) if self._output_layer is not None: cell_outputs = self._output_layer(cell_outputs) beam_search_output, beam_search_state = _beam_search_step( time=time, logits=cell_outputs, next_cell_state=next_cell_state, beam_state=state, batch_size=batch_size, beam_width=beam_width, end_token=end_token, length_penalty_weight=length_penalty_weight) finished = beam_search_state.finished sample_ids = beam_search_output.predicted_ids next_inputs = control_flow_ops.cond( math_ops.reduce_all(finished), lambda: self._start_inputs, lambda: self._embedding_fn(sample_ids)) if self._use_pos_embedding: next_inputs += self._pos_embedding_fn(ops.convert_to_tensor(time)) return (beam_search_output, beam_search_state, next_inputs, finished)
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width, end_token, length_penalty_weight): """Performs a single step of Beam Search Decoding. Args: time: Beam search time step, should start at 0. At time 0 we assume that all beams are equal and consider only the first beam for continuations. logits: Logits at the current time step. A tensor of shape `[batch_size, beam_width, vocab_size]` next_cell_state: The next state from the cell, e.g. an instance of AttentionWrapperState if the cell is attentional. beam_state: Current state of the beam search. An instance of `BeamSearchDecoderState`. batch_size: The batch size for this input. beam_width: Python int. The size of the beams. end_token: The int32 end token. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. Returns: A new beam state. """ static_batch_size = tensor_util.constant_value(batch_size) # Calculate the current lengths of the predictions prediction_lengths = beam_state.lengths previously_finished = beam_state.finished # Calculate the total log probs for the new hypotheses # Final Shape: [batch_size, beam_width, vocab_size] step_log_probs = nn_ops.log_softmax(logits) step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished) total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs # Calculate the continuation lengths by adding to all continuing beams. vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1] lengths_to_add = array_ops.one_hot( indices=array_ops.fill([batch_size, beam_width], end_token), depth=vocab_size, on_value=np.int64(0), off_value=np.int64(1), dtype=dtypes.int64) add_mask = tf.cast(math_ops.logical_not(previously_finished), tf.int64) lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) # Calculate the scores for each beam scores = _get_scores( log_probs=total_probs, sequence_lengths=new_prediction_lengths, length_penalty_weight=length_penalty_weight, dtype=logits.dtype) time = ops.convert_to_tensor(time, name="time") # During the first time step we only consider the initial beam scores_shape = array_ops.shape(scores) scores_flat = array_ops.reshape(scores, [batch_size, -1]) # Pick the next beams according to the specified successors function next_beam_size = ops.convert_to_tensor( beam_width, dtype=dtypes.int32, name="beam_width") next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) next_beam_scores.set_shape([static_batch_size, beam_width]) word_indices.set_shape([static_batch_size, beam_width]) # Pick out the probs, beam_ids, and states according to the chosen # predictions next_beam_probs = _tensor_gather_helper( gather_indices=word_indices, gather_from=total_probs, batch_size=batch_size, range_size=beam_width * vocab_size, gather_shape=[-1], name="next_beam_probs") # Note: just doing the following # math_ops.to_int32(word_indices % vocab_size, # name="next_beam_word_ids") # would be a lot cleaner but for reasons unclear, that hides the results of # the op which prevents capturing it with tfdbg debug ops. raw_next_word_ids = math_ops.mod( word_indices, vocab_size, name="next_beam_word_ids") next_word_ids = tf.cast(raw_next_word_ids, tf.int32) next_beam_ids = tf.cast(word_indices / vocab_size, name="next_beam_parent_ids", dtype=tf.int32) # Append new ids to current predictions previously_finished = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=previously_finished, batch_size=batch_size, range_size=beam_width, gather_shape=[-1]) next_finished = math_ops.logical_or( previously_finished, math_ops.equal(next_word_ids, end_token), name="next_beam_finished") # Calculate the length of the next predictions. # 1. Finished beams remain unchanged. # 2. Beams that are now finished (EOS predicted) have their length # increased by 1. # 3. Beams that are not yet finished have their length increased by 1. lengths_to_add = tf.cast(math_ops.logical_not(previously_finished), tf.int64) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, batch_size=batch_size, range_size=beam_width, gather_shape=[-1]) next_prediction_len += lengths_to_add # Pick out the cell_states according to the next_beam_ids. We use a # different gather_shape here because the cell_state tensors, i.e. # the tensors that would be gathered from, all have dimension # greater than two and we need to preserve those dimensions. # pylint: disable=g-long-lambda next_cell_state = nest.map_structure( lambda gather_from: _maybe_tensor_gather_helper( gather_indices=next_beam_ids, gather_from=gather_from, batch_size=batch_size, range_size=beam_width, gather_shape=[batch_size * beam_width, -1]), next_cell_state) # pylint: enable=g-long-lambda next_state = BeamSearchDecoderState( cell_state=next_cell_state, log_probs=next_beam_probs, lengths=next_prediction_len, finished=next_finished) output = BeamSearchDecoderOutput( scores=next_beam_scores, predicted_ids=next_word_ids, parent_ids=next_beam_ids) return output, next_state def _get_scores(log_probs, sequence_lengths, length_penalty_weight, dtype=dtypes.float32): """Calculates scores for beam search hypotheses. Args: log_probs: The log probabilities with shape `[batch_size, beam_width, vocab_size]`. sequence_lengths: The array of sequence lengths. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. Returns: The scores normalized by the length_penalty. """ length_penality_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) return log_probs / math_ops.cast(length_penality_, dtype) def _length_penalty(sequence_lengths, penalty_factor): """Calculates the length penalty. See https://arxiv.org/abs/1609.08144. Returns the length penalty tensor: ``` [(5+sequence_lengths)/6]**penalty_factor ``` where all operations are performed element-wise. Args: sequence_lengths: `Tensor`, the sequence lengths of each hypotheses. penalty_factor: A scalar that weights the length penalty. Returns: If the penalty is `0`, returns the scalar `1.0`. Otherwise returns the length penalty factor, a tensor with the same shape as `sequence_lengths`. """ penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor") penalty_factor.set_shape(()) # penalty should be a scalar. static_penalty = tensor_util.constant_value(penalty_factor) if static_penalty is not None and static_penalty == 0: return 1.0 return math_ops.div((5. + math_ops.to_float(sequence_lengths)) ** penalty_factor, (5. + 1.)**penalty_factor) def _mask_probs(probs, eos_token, finished): """Masks log probabilities. The result is that finished beams allocate all probability mass to eos and unfinished beams remain unchanged. Args: probs: Log probabiltiies of shape `[batch_size, beam_width, vocab_size]` eos_token: An int32 id corresponding to the EOS token to allocate probability to. finished: A boolean tensor of shape `[batch_size, beam_width]` that specifies which elements in the beam are finished already. Returns: A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished beams stay unchanged and finished beams are replaced with a tensor with all probability on the EOS token. """ vocab_size = array_ops.shape(probs)[2] # All finished examples are replaced with a vector that has all # probability on EOS finished_row = array_ops.one_hot( eos_token, vocab_size, dtype=probs.dtype, on_value=ops.convert_to_tensor(0., dtype=probs.dtype), off_value=probs.dtype.min) finished_probs = array_ops.tile( array_ops.reshape(finished_row, [1, 1, -1]), array_ops.concat([array_ops.shape(finished), [1]], 0)) finished_mask = array_ops.tile( array_ops.expand_dims(finished, 2), [1, 1, vocab_size]) return array_ops.where(finished_mask, finished_probs, probs) def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, range_size, gather_shape): """Maybe applies _tensor_gather_helper. This applies _tensor_gather_helper when the gather_from dims is at least as big as the length of gather_shape. This is used in conjunction with nest so that we don't apply _tensor_gather_helper to inapplicable values like scalars. Args: gather_indices: The tensor indices that we use to gather. gather_from: The tensor that we are gathering from. batch_size: The batch size. range_size: The number of values in each range. Likely equal to beam_width. gather_shape: What we should reshape gather_from to in order to preserve the correct values. An example is when gather_from is the attention from an AttentionWrapperState with shape [batch_size, beam_width, attention_size]. There, we want to preserve the attention_size elements, so gather_shape is [batch_size * beam_width, -1]. Then, upon reshape, we still have the attention_size as desired. Returns: output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)] or the original tensor if its dimensions are too small. """ _check_maybe(gather_from) if gather_from.shape.ndims >= len(gather_shape): return _tensor_gather_helper( gather_indices=gather_indices, gather_from=gather_from, batch_size=batch_size, range_size=range_size, gather_shape=gather_shape) else: return gather_from def _tensor_gather_helper(gather_indices, gather_from, batch_size, range_size, gather_shape, name=None): """Helper for gathering the right indices from the tensor. This works by reshaping gather_from to gather_shape (e.g. [-1]) and then gathering from that according to the gather_indices, which are offset by the right amounts in order to preserve the batch order. Args: gather_indices: The tensor indices that we use to gather. gather_from: The tensor that we are gathering from. batch_size: The input batch size. range_size: The number of values in each range. Likely equal to beam_width. gather_shape: What we should reshape gather_from to in order to preserve the correct values. An example is when gather_from is the attention from an AttentionWrapperState with shape [batch_size, beam_width, attention_size]. There, we want to preserve the attention_size elements, so gather_shape is [batch_size * beam_width, -1]. Then, upon reshape, we still have the attention_size as desired. name: The tensor name for set of operations. By default this is 'tensor_gather_helper'. The final output is named 'output'. Returns: output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)] """ with ops.name_scope(name, "tensor_gather_helper"): range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1) gather_indices = array_ops.reshape(gather_indices + range_, [-1]) output = array_ops.gather( array_ops.reshape(gather_from, gather_shape), gather_indices) final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)] static_batch_size = tensor_util.constant_value(batch_size) final_static_shape = ( tensor_shape.TensorShape([static_batch_size]).concatenate( gather_from.shape[1:1 + len(gather_shape)])) output = array_ops.reshape(output, final_shape, name="output") output.set_shape(final_static_shape) return output