Source code for parts.rnns.utils

# Copyright (c) 2017 NVIDIA Corporation
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import math

from six.moves import range
import tensorflow as tf

from tensorflow.python.ops.rnn_cell import ResidualWrapper, DropoutWrapper
from open_seq2seq.parts.rnns.weight_drop import WeightDropLayerNormBasicLSTMCell
from open_seq2seq.parts.rnns.slstm import BasicSLSTMCell
from open_seq2seq.parts.rnns.glstm import GLSTMCell
from open_seq2seq.parts.rnns.zoneout import ZoneoutWrapper


def single_cell(
    cell_class,
    cell_params,
    dp_input_keep_prob=1.0,
    dp_output_keep_prob=1.0,
    recurrent_keep_prob=1.0,
    input_weight_keep_prob=1.0,
    recurrent_weight_keep_prob=1.0,
    weight_variational=False,
    dropout_seed=None,
    zoneout_prob=0.,
    training=True,
    residual_connections=False,
    awd_initializer=False,
    variational_recurrent=False,  # in case they want to use DropoutWrapper
    dtype=None,
):
    """Creates an instance of the rnn cell.

    Such cell describes one step one layer and can include residual
    connection and/or dropout.

    Args:
      cell_class: Tensorflow RNN cell class. It is called as
        ``cell_class(**cell_params)``.
      cell_params (dict): cell parameters. The dict is copied before any
        extra keys are added, so the caller's dict is never modified.
      dp_input_keep_prob (float): (default: 1.0) input dropout keep
        probability.
      dp_output_keep_prob (float): (default: 1.0) output dropout keep
        probability.
      recurrent_keep_prob (float): forwarded to the cell as
        ``recurrent_keep_prob`` when it is below 1.0 and the cell class is a
        ``WeightDropLayerNormBasicLSTMCell``; ignored otherwise.
      input_weight_keep_prob (float): forwarded as
        ``input_weight_keep_prob`` under the same conditions.
      recurrent_weight_keep_prob (float): forwarded as
        ``recurrent_weight_keep_prob`` under the same conditions.
      weight_variational (bool): forwarded as ``weight_variational`` to a
        ``WeightDropLayerNormBasicLSTMCell`` when truthy.
      dropout_seed: seed passed to ``DropoutWrapper`` and, when not None, to
        a ``WeightDropLayerNormBasicLSTMCell`` as ``dropout_seed``.
      zoneout_prob (float): zoneout probability. Applying both zoneout and
        dropout is currently not supported.
      training (bool): dropout is only added when this is True; it is also
        passed to ``ZoneoutWrapper`` as ``is_training``.
      residual_connections (bool): whether to add residual connection.
      awd_initializer (bool): when True, sets ``cell_params['initializer']``
        to a uniform initializer over
        ``[-1/sqrt(num_units), 1/sqrt(num_units)]``; requires a
        ``'num_units'`` entry in ``cell_params``.
      variational_recurrent (bool): passed through to ``DropoutWrapper``.
      dtype: passed through to ``DropoutWrapper``.

    Returns:
      TF RNN instance

    Raises:
      ValueError: if ``zoneout_prob > 0`` together with an input or output
        dropout keep probability below 1.0.
    """
    # Work on a copy: the keys injected below must not leak back into the
    # caller's dict, which may be reused for other cells.
    cell_params = dict(cell_params)

    if awd_initializer:
        val = 1.0 / math.sqrt(cell_params['num_units'])
        cell_params['initializer'] = tf.random_uniform_initializer(
            minval=-val, maxval=val
        )

    # Matching on the class name avoids requiring an exact class identity.
    if 'WeightDropLayerNormBasicLSTMCell' in str(cell_class):
        if recurrent_keep_prob < 1.0:
            cell_params['recurrent_keep_prob'] = recurrent_keep_prob
        if input_weight_keep_prob < 1.0:
            cell_params['input_weight_keep_prob'] = input_weight_keep_prob
        if recurrent_weight_keep_prob < 1.0:
            cell_params['recurrent_weight_keep_prob'] = (
                recurrent_weight_keep_prob
            )
        if weight_variational:
            # which is basically True
            cell_params['weight_variational'] = weight_variational
        # `is not None` so that a seed of 0 is still forwarded.
        if dropout_seed is not None:
            cell_params['dropout_seed'] = dropout_seed

    cell = cell_class(**cell_params)

    if residual_connections:
        cell = ResidualWrapper(cell)

    if zoneout_prob > 0. and (
        dp_input_keep_prob < 1.0 or dp_output_keep_prob < 1.0
    ):
        raise ValueError(
            "Currently applying both dropout and zoneout on the same cell. "
            "This is currently not supported."
        )

    # Parenthesised on purpose: `and` binds tighter than `or`, so without the
    # parentheses input dropout would be applied even when not training.
    if (dp_input_keep_prob != 1.0 or dp_output_keep_prob != 1.0) and training:
        cell = DropoutWrapper(
            cell,
            input_keep_prob=dp_input_keep_prob,
            output_keep_prob=dp_output_keep_prob,
            variational_recurrent=variational_recurrent,
            dtype=dtype,
            seed=dropout_seed,
        )

    if zoneout_prob > 0.:
        cell = ZoneoutWrapper(cell, zoneout_prob, is_training=training)

    return cell