"""Implement https://arxiv.org/abs/1709.02755
Copy from LSTM, and make it functionally correct with minimum code change
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from six.moves import range
import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import rnn_cell
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
def _version_tuple(version):
    """Return the numeric (major, minor, patch) components of `version`.

    Each dot-separated piece contributes its leading digits ("0-rc1" -> 0, a
    piece with no leading digits -> 0). The result is padded with zeros to
    three components, so "1.2" and "1.2.0" compare equal. This gives a numeric
    ordering; a plain string comparison would rank "1.10.0" below "1.2.0".
    """
    import re  # local import keeps this module-level helper self-contained

    numbers = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    numbers += [0] * (3 - len(numbers))
    return tuple(numbers)


# Installed TensorFlow version as a comparable integer tuple.
_TF_VERSION = _version_tuple(tf.__version__)
# Variable names created by `_linear`: releases older than 1.2.0 use the legacy
# "biases"/"weights" names, newer ones use "bias"/"kernel".
_BIAS_VARIABLE_NAME = "biases" if _TF_VERSION < (1, 2, 0) else "bias"
_WEIGHTS_VARIABLE_NAME = "weights" if _TF_VERSION < (1, 2, 0) else "kernel"
# TODO: must implement all abstract methods
class BasicSLSTMCell(rnn_cell.RNNCell):
    """Basic SLSTM recurrent network cell.

    The implementation is based on: https://arxiv.org/abs/1709.02755.

    It is adapted from an LSTM cell. The difference is that the gate
    pre-activations are a linear function of the current `inputs` only. The
    previous hidden state `h` is carried in the state but does not feed the
    gates.
    """

    def __init__(self, num_units, forget_bias=1.0,
                 state_is_tuple=True, activation=None, reuse=None):
        """Initialize the basic SLSTM cell.

        Args:
          num_units: int, The number of units in the SLSTM cell.
          forget_bias: float, The bias added to the forget gate pre-activation.
            Must be set to `0.0` manually when restoring from CudnnLSTM-trained
            checkpoints.
          state_is_tuple: If True, accepted and returned states are 2-tuples of
            the `c_state` and `m_state`. If False, they are concatenated
            along the column axis. The latter behavior will soon be deprecated.
          activation: Activation function of the inner states. Default: `tanh`.
          reuse: (optional) Python boolean describing whether to reuse variables
            in an existing scope. If not `True`, and the existing scope already
            has the given variables, an error is raised.
        """
        super(BasicSLSTMCell, self).__init__(_reuse=reuse)
        if not state_is_tuple:
            logging.warning(
                "%s: Using a concatenated state is slower and will soon be "
                "deprecated. Use state_is_tuple=True.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh

    @property
    def state_size(self):
        """State size: `LSTMStateTuple(num_units, num_units)` for tuple states,
        otherwise `2 * num_units` for the concatenated `[c, h]` state."""
        if self._state_is_tuple:
            return rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
        return 2 * self._num_units

    @property
    def output_size(self):
        """Output size: the cell emits its new hidden state of `num_units`."""
        return self._num_units

    # TODO: does not match signature of the base method
    def call(self, inputs, state):
        """Run one step of the SLSTM cell.

        Args:
          inputs: `2-D` tensor with shape `[batch_size x input_size]`.
          state: An `LSTMStateTuple` of state tensors `(c, h)`, each shaped
            `[batch_size x num_units]`, if `state_is_tuple` has been set to
            `True`. Otherwise, a `Tensor` shaped `[batch_size x 2 * num_units]`
            holding `c` and `h` concatenated along axis 1.

        Returns:
          A pair containing the new hidden state, and the new state (either a
          `LSTMStateTuple` or a concatenated state, depending on
          `state_is_tuple`).
        """
        sigmoid = math_ops.sigmoid
        # Only the cell state `c` is needed: the previous hidden state `h` does
        # not take part in computing the gates.
        if self._state_is_tuple:
            c, _ = state
        else:
            c, _ = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        # Parameters of the gates are concatenated into one multiply for
        # efficiency. This is the SLSTM change relative to an LSTM cell: the
        # projection is applied to `inputs` alone instead of `[inputs, h]`.
        concat = _linear(inputs, 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

        new_c = (
            c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        # Return the state in the same layout that `state_size` advertises.
        if self._state_is_tuple:
            new_state = rnn_cell.LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state
def _linear(args,
            output_size,
            bias,
            bias_initializer=None,
            kernel_initializer=None):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

    Implemented as one matmul of the column-wise concatenation of `args`
    against a single weight variable, optionally followed by a bias add.
    Variables are created in the current variable scope.

    Args:
      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
      output_size: int, second dimension of W[i].
      bias: boolean, whether to add a bias term or not.
      bias_initializer: starting value to initialize the bias
        (default is all zeros).
      kernel_initializer: starting value to initialize the weight. If None,
        `get_variable`'s default initializer is used.

    Returns:
      A 2D Tensor with shape [batch x output_size] equal to
      sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

    Raises:
      ValueError: if some of the arguments has unspecified or wrong shape.
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
        args = [args]

    # Calculate the total size of arguments on dimension 1; this is the number
    # of rows of the single weight matrix.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
        if shape.ndims != 2:
            raise ValueError("linear is expecting 2D arguments: %s" % shapes)
        if shape[1].value is None:
            raise ValueError("linear expects shape[1] to be provided for shape %s, "
                             "but saw %s" % (shape, shape[1]))
        total_arg_size += shape[1].value

    # The weights (and bias) take the dtype of the first argument.
    dtype = args[0].dtype

    # Now the computation.
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
        weights = vs.get_variable(
            _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
            dtype=dtype,
            initializer=kernel_initializer)
        if len(args) == 1:
            # Skip the concat when there is a single argument.
            res = math_ops.matmul(args[0], weights)
        else:
            res = math_ops.matmul(array_ops.concat(args, 1), weights)
        if not bias:
            return res
        with vs.variable_scope(outer_scope) as inner_scope:
            # Keep the 1-D bias in one piece even if the enclosing scope sets a
            # partitioner (as TensorFlow's own rnn_cell `_linear` does).
            inner_scope.set_partitioner(None)
            if bias_initializer is None:
                bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
            biases = vs.get_variable(
                _BIAS_VARIABLE_NAME, [output_size],
                dtype=dtype,
                initializer=bias_initializer)
        return nn_ops.bias_add(res, biases)