"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from six.moves import range
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
# pylint: enable=protected-access


class GLSTMCell(rnn_cell_impl.RNNCell):
"""Group LSTM cell (G-LSTM).
The implementation is based on:
https://arxiv.org/abs/1703.10722
O. Kuchaiev and B. Ginsburg
"Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
"""

  def __init__(self, num_units, initializer=None, num_proj=None,
number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh,
reuse=None):
"""Initialize the parameters of G-LSTM cell.
Args:
num_units: int, The number of units in the G-LSTM cell
initializer: (optional) The initializer to use for the weight and
projection matrices.
num_proj: (optional) int, The output dimensionality for the projection
matrices. If None, no projection is performed.
number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, the cell is equivalent to a standard
        LSTM cell.
forget_bias: Biases of the forget gate are initialized by default to 1
in order to reduce the scale of forgetting at the beginning of
the training.
activation: Activation function of the inner states.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already
has the given variables, an error is raised.
Raises:
ValueError: If `num_units` or `num_proj` is not divisible by
`number_of_groups`.
"""
super(GLSTMCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._initializer = initializer
self._num_proj = num_proj
self._forget_bias = forget_bias
self._activation = activation
self._number_of_groups = number_of_groups
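    # _group_shape is [input columns per group, cell units per group]:
    # with a projection each group consumes num_proj/number_of_groups
    # recurrent columns, otherwise num_units/number_of_groups.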
if self._num_units % self._number_of_groups != 0:
raise ValueError("num_units must be divisible by number_of_groups")
if self._num_proj:
if self._num_proj % self._number_of_groups != 0:
raise ValueError("num_proj must be divisible by number_of_groups")
self._group_shape = [int(self._num_proj / self._number_of_groups),
int(self._num_units / self._number_of_groups)]
else:
self._group_shape = [int(self._num_units / self._number_of_groups),
int(self._num_units / self._number_of_groups)]
if num_proj:
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
self._output_size = num_proj
else:
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
self._output_size = num_units
self._linear1 = [None] * self._number_of_groups
self._linear2 = None
@property
def state_size(self):
return self._state_size
@property
def output_size(self):
return self._output_size
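
  # `call` below uses `_get_input_for_group`, which is missing from this
  # excerpt. The sketch below is a minimal implementation consistent with
  # that usage: it slices out the `group_id`-th contiguous block of
  # `group_size` columns. `self._batch_size` is set at the top of `call`
  # before this helper is first invoked.
  def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices out the columns of `inputs` belonging to group `group_id`.

    Args:
      inputs: Tensor, 2D, [batch x num_units], cell input or previous state.
      group_id: int, index of the group whose slice to extract.
      group_size: int, number of columns per group.

    Returns:
      A Tensor, 2D, [batch x group_size], the input slice for `group_id`.
    """
    return array_ops.slice(input_=inputs,
                           begin=[0, group_id * group_size],
                           size=[self._batch_size, group_size],
                           name="GLSTM_group%d_input_generation" % group_id)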

  def call(self, inputs, state):
"""Run one step of G-LSTM.
Args:
      inputs: input Tensor, 2D, [batch x input_size].
      state: an `LSTMStateTuple` of state Tensors, both `2-D`, holding the
        previous cell state `c_state` and output `m_state`.
Returns:
A tuple containing:
- A `2-D, [batch x output_dim]`, Tensor representing the output of the
G-LSTM after reading `inputs` when previous state was `state`.
Here output_dim is:
num_proj if num_proj was set,
num_units otherwise.
- LSTMStateTuple representing the new state of G-LSTM cell
after reading `inputs` when the previous state was `state`.
Raises:
ValueError: If input size cannot be inferred from inputs via
static shape inference.
"""
(c_prev, m_prev) = state
self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
input_size = inputs.shape[-1].value or array_ops.shape(inputs)[-1]
dtype = inputs.dtype
scope = vs.get_variable_scope()
with vs.variable_scope(scope, initializer=self._initializer):
i_parts = []
j_parts = []
f_parts = []
o_parts = []
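      # Each group independently computes its i/j/f/o gate pre-activations
      # from its own slice of the input and of the previous (projected)
      # state; the per-group pieces are concatenated back together below.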
for group_id in range(self._number_of_groups):
with vs.variable_scope("group%d" % group_id):
          # Select this group's slice of the input and of the previous
          # (projected) state. The slice width is derived from the actual
          # input size; using self._group_shape[0] here would only be
          # correct when the input dimension equals num_units.
          x_g_id = array_ops.concat(
              [self._get_input_for_group(
                  inputs, group_id,
                  int(input_size / self._number_of_groups)),
               self._get_input_for_group(
                   m_prev, group_id,
                   int(self._output_size / self._number_of_groups))],
              axis=1)
if self._linear1[group_id] is None:
self._linear1[group_id] = _Linear(
x_g_id, 4 * self._group_shape[1],
False,
)
R_k = self._linear1[group_id](x_g_id) # pylint: disable=invalid-name
i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)
i_parts.append(i_k)
j_parts.append(j_k)
f_parts.append(f_k)
o_parts.append(o_k)
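      # The gate biases are created once at the full num_units width,
      # outside the per-group scopes, so a single bias vector spans the
      # concatenation of all groups.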
bi = vs.get_variable(
name="bias_i",
shape=[self._num_units],
dtype=dtype,
initializer=init_ops.constant_initializer(0.0, dtype=dtype),
)
bj = vs.get_variable(
name="bias_j",
shape=[self._num_units],
dtype=dtype,
initializer=init_ops.constant_initializer(0.0, dtype=dtype),
)
bf = vs.get_variable(
name="bias_f",
shape=[self._num_units],
dtype=dtype,
initializer=init_ops.constant_initializer(0.0, dtype=dtype),
)
bo = vs.get_variable(
name="bias_o",
shape=[self._num_units],
dtype=dtype,
initializer=init_ops.constant_initializer(0.0, dtype=dtype),
)
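      # Standard LSTM update on the concatenated gates:
      #   c_t = sigmoid(f + forget_bias) * c_{t-1} + sigmoid(i) * tanh(j)
      #   m_t = sigmoid(o) * activation(c_t)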
i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)
c = (math_ops.sigmoid(f + self._forget_bias) * c_prev +
math_ops.sigmoid(i) * math_ops.tanh(j))
m = math_ops.sigmoid(o) * self._activation(c)
if self._num_proj is not None:
with vs.variable_scope("projection"):
if self._linear2 is None:
self._linear2 = _Linear(m, self._num_proj, False)
m = self._linear2(m)
new_state = rnn_cell_impl.LSTMStateTuple(c, m)
return m, new_state
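

# A minimal smoke test (a sketch, not part of the original module): builds a
# GLSTMCell and runs one step. Assumes a TF 1.x runtime where
# `tensorflow.contrib` is available and eager execution is disabled; the
# shapes are illustrative.
if __name__ == "__main__":
  import numpy as np
  import tensorflow as tf

  batch_size, input_size = 2, 16
  cell = GLSTMCell(num_units=8, num_proj=4, number_of_groups=2)
  inputs = tf.placeholder(tf.float32, [batch_size, input_size])
  state = cell.zero_state(batch_size, tf.float32)
  output, new_state = cell(inputs, state)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(output,
                   {inputs: np.ones([batch_size, input_size], np.float32)})
    print("output shape:", out.shape)  # expected: (2, 4), i.e. num_proj wide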