Source code for parts.rnns.slstm

"""Implement https://arxiv.org/abs/1709.02755

Copied from the LSTM cell, with the minimum code changes needed to make it functionally correct.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from six.moves import range

import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import rnn_cell
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest

# Compare parsed version components rather than raw strings (a plain string
# comparison would mis-order TF >= 1.10 against "1.2.0").
_TF_VERSION = tuple(int(v) for v in tf.__version__.split(".")[:2])
_BIAS_VARIABLE_NAME = "biases" if _TF_VERSION < (1, 2) else "bias"
_WEIGHTS_VARIABLE_NAME = "weights" if _TF_VERSION < (1, 2) else "kernel"


# TODO: must implement all abstract methods
class BasicSLSTMCell(rnn_cell.RNNCell):
  """Basic SLSTM recurrent network cell.

  The implementation is based on: https://arxiv.org/abs/1709.02755.
  """

  def __init__(self, num_units, forget_bias=1.0,
               state_is_tuple=True, activation=None, reuse=None):
    """Initialize the basic SLSTM cell.

    Args:
      num_units: int, The number of units in the SLSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(BasicSLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return (rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  # TODO: does not match signature of the base method
  def call(self, inputs, state):
    """SLSTM cell: an LSTM variant whose gates depend only on the current input.

    Args:
      inputs: `2-D` tensor with shape `[batch_size x input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
        `True`.  Otherwise, a `Tensor` shaped
        `[batch_size x 2 * self.state_size]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    # Unlike the standard LSTM, the gate pre-activations are computed from the
    # current input only, not from the previous hidden state `h`:
    # concat = _linear([inputs, h], 4 * self._num_units, True)
    concat = _linear(inputs, 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = rnn_cell.LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
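

# A minimal usage sketch, added for illustration and not part of the original
# module.  It assumes the TF 1.x graph-mode API (`tf.placeholder`,
# `tf.nn.dynamic_rnn`); the batch/time/depth sizes below are arbitrary.
def _example_usage():
  """Builds one SLSTM layer over a [batch, time, depth] input (sketch only)."""
  inputs = tf.placeholder(tf.float32, [None, 20, 128])  # [batch, time, depth]
  cell = BasicSLSTMCell(num_units=256)
  # `outputs` is [batch, time, 256]; `final_state` is an LSTMStateTuple(c, h).
  outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
  return outputs, final_state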


def _linear(args,
            output_size,
            bias,
            bias_initializer=None,
            kernel_initializer=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_initializer: starting value to initialize the bias
      (default is all zeros).
    kernel_initializer: starting value to initialize the weight.

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments have unspecified or wrong shapes.
  """
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape() for a in args]
  for shape in shapes:
    if shape.ndims != 2:
      raise ValueError("linear is expecting 2D arguments: %s" % shapes)
    if shape[1].value is None:
      raise ValueError("linear expects shape[1] to be provided for shape %s, "
                       "but saw %s" % (shape, shape[1]))
    else:
      total_arg_size += shape[1].value

  dtype = [a.dtype for a in args][0]

  # Now the computation.
  scope = vs.get_variable_scope()
  with vs.variable_scope(scope) as outer_scope:
    weights = vs.get_variable(
        _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
        dtype=dtype,
        initializer=kernel_initializer)
    if len(args) == 1:
      res = math_ops.matmul(args[0], weights)
    else:
      res = math_ops.matmul(array_ops.concat(args, 1), weights)
    if not bias:
      return res
    with vs.variable_scope(outer_scope) as inner_scope:
      inner_scope.set_partitioner(None)
      if bias_initializer is None:
        bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
      biases = vs.get_variable(
          _BIAS_VARIABLE_NAME, [output_size],
          dtype=dtype,
          initializer=bias_initializer)
    return nn_ops.bias_add(res, biases)
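
# A brief sketch of how the helper above is used (illustrative only, not part
# of the original module): inside a variable scope it creates the kernel/bias
# variables and applies a single affine map to its (concatenated) inputs, e.g.
#
#   with vs.variable_scope("slstm_linear_demo"):
#     x = tf.placeholder(tf.float32, [None, 128])   # [batch, depth], depth known
#     y = _linear(x, output_size=512, bias=True)    # -> [batch, 512]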