encoders

This package contains various encoders. An encoder typically takes input data and produces a representation of it.

encoder

class encoders.encoder.Encoder(params, model, name='encoder', mode='train')[source]

Bases: object

Abstract class from which all encoders must inherit.

__init__(params, model, name='encoder', mode='train')[source]

Encoder constructor. Note that encoder constructors should not modify the TensorFlow graph; all graph construction should happen in the self._encode() method.

Parameters:
  • params (dict) – parameters describing the encoder. All supported parameters are listed by the get_required_params() and get_optional_params() functions.
  • model (instance of a class derived from Model) – parent model that created this encoder. Could be None if no model access is required for the use case.
  • name (str) – name for encoder variable scope.
  • mode (str) – mode the encoder is going to be run in. Could be “train”, “eval” or “infer”.

Config parameters:

  • initializer — any valid TensorFlow initializer. If no initializer is provided, model initializer will be used.
  • initializer_params (dict) — dictionary that will be passed to initializer __init__ method.
  • regularizer — any valid TensorFlow regularizer. If no regularizer is provided, model regularizer will be used.
  • regularizer_params (dict) — dictionary that will be passed to regularizer __init__ method.
  • dtype — model dtype. Could be either tf.float16, tf.float32 or “mixed”. For details see mixed precision training section in docs. If no dtype is provided, model dtype will be used.
_cast_types(input_dict)[source]

This function automatically casts all inputs to the encoder dtype.

Parameters:input_dict (dict) – dictionary passed to self._encode() method.
Returns:same as input_dict, but with all Tensors cast to encoder dtype.
Return type:dict
_encode(input_dict)[source]

This is the main function, which should construct the encoder graph. Typically, an encoder takes a raw input sequence as input and produces some hidden representation as output.

Parameters:input_dict (dict) –

dictionary containing encoder inputs. If the encoder is used with models.encoder_decoder class, input_dict will have the following content:

{
  "source_tensors": data_layer.input_tensors['source_tensors']
}
Returns:dictionary of encoder outputs. Return all necessary outputs. Typically this will be just:
{
  "outputs": outputs,
  "state": state,
}
Return type:dict
encode(input_dict)[source]

Wrapper around self._encode() method. Here name, initializer and dtype are set in the variable scope and then self._encode() method is called.

Parameters:input_dict (dict) – see self._encode() docs.
Returns:see self._encode() docs.
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
mode

Mode the encoder is run in.

name

Encoder name.

params

Parameters used to construct the encoder (dictionary).
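
The subclassing contract described above can be sketched without TensorFlow. The Encoder class below is a simplified stand-in written only for this illustration, not the library's implementation; MeanEncoder, its scale parameter and the validation logic are hypothetical. The sketch only shows the division of labour: parameters are declared in the static methods, the constructor stores configuration without building anything, and encode() wraps _encode().

```python
class Encoder:
    """Simplified stand-in for the abstract encoder base class (illustration only)."""

    @staticmethod
    def get_required_params():
        return {}

    @staticmethod
    def get_optional_params():
        return {"dtype": str}

    def __init__(self, params, model, name="encoder", mode="train"):
        # The constructor only validates and stores configuration;
        # all real work is deferred to _encode().
        required = self.get_required_params()
        optional = self.get_optional_params()
        for key in required:
            if key not in params:
                raise ValueError("missing required parameter: " + key)
        for key, value in params.items():
            expected_type = required.get(key, optional.get(key))
            if expected_type is None:
                raise ValueError("unknown parameter: " + key)
            if not isinstance(value, expected_type):
                raise TypeError(key + " has the wrong type")
        self.params = params
        self.model = model
        self.name = name
        self.mode = mode

    def encode(self, input_dict):
        # In the documented class, this wrapper sets name, initializer and
        # dtype in the variable scope before delegating to _encode().
        return self._encode(input_dict)

    def _encode(self, input_dict):
        raise NotImplementedError("subclasses must implement _encode()")


class MeanEncoder(Encoder):
    """Hypothetical encoder: averages each input sequence and scales the result."""

    @staticmethod
    def get_required_params():
        return dict(Encoder.get_required_params(), scale=float)

    def _encode(self, input_dict):
        sequences, lengths = input_dict["source_tensors"]
        scale = self.params["scale"]
        # Only the first n (unpadded) elements of each sequence are averaged.
        outputs = [scale * sum(seq[:n]) / n for seq, n in zip(sequences, lengths)]
        return {"outputs": outputs, "src_length": lengths}


encoder = MeanEncoder({"scale": 2.0}, model=None)
result = encoder.encode(
    {"source_tensors": [[[1.0, 3.0, 0.0], [4.0, 0.0, 0.0]], [2, 1]]}
)
```

Keeping the constructor free of side effects means an encoder object can be created and inspected (for example, to check its parameters) before any graph exists.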

ds2_encoder

class encoders.ds2_encoder.DeepSpeech2Encoder(params, model, name='ds2_encoder', mode='train')[source]

Bases: encoders.encoder.Encoder

DeepSpeech-2 like encoder.

__init__(params, model, name='ds2_encoder', mode='train')[source]

DeepSpeech-2 like encoder constructor.

See parent class for arguments description.

Config parameters:

  • dropout_keep_prob (float) — keep probability for dropout.

  • conv_layers (list) — list with the description of convolutional layers. For example:

    "conv_layers": [
      {
        "kernel_size": [11, 41], "stride": [2, 2],
        "num_channels": 32, "padding": "SAME",
      },
      {
        "kernel_size": [11, 21], "stride": [1, 2],
        "num_channels": 64, "padding": "SAME",
      },
      {
        "kernel_size": [11, 21], "stride": [1, 2],
        "num_channels": 96, "padding": "SAME",
      },
    ]
    
  • activation_fn — activation function to use.

  • num_rnn_layers — number of RNN layers to use.

  • rnn_type (string) — could be “lstm”, “gru”, “cudnn_gru”, “cudnn_lstm” or “layernorm_lstm”.

  • rnn_unidirectional (bool) — whether to use uni-directional or bi-directional RNNs.

  • rnn_cell_dim (int) — dimension of RNN cells.

  • row_conv (bool) — whether to use a “row” (“in plane”) convolutional layer after RNNs.

  • row_conv_width (int) — width parameter for “row” convolutional layer.

  • n_hidden (int) — number of hidden units for the last fully connected layer.

  • data_format (string) — could be either “channels_first”, “channels_last”, “BCTF”, “BTFC”, “BCFT”, “BFTC”.

    Defaults to “channels_last”.

  • bn_momentum (float) — momentum for batch norm. Defaults to 0.99.

  • bn_epsilon (float) — epsilon for batch norm. Defaults to 1e-3.
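
Strided convolutions shorten the sequence, so the src_length returned by the encoder is generally smaller than the input length. The sketch below estimates that length for the conv_layers example above. It assumes that the first element of kernel_size and stride refers to the time axis and that the standard TensorFlow output-size rules for "SAME" and "VALID" padding apply; the function names are hypothetical and not part of the package.

```python
import math


def conv_output_length(length, kernel_size, stride, padding):
    """Length of one axis after a convolution, using TensorFlow's padding rules."""
    if padding == "SAME":
        return math.ceil(length / stride)
    if padding == "VALID":
        return math.ceil((length - kernel_size + 1) / stride)
    raise ValueError("unknown padding: " + padding)


# The conv_layers example from the config description above.
conv_layers = [
    {"kernel_size": [11, 41], "stride": [2, 2], "num_channels": 32, "padding": "SAME"},
    {"kernel_size": [11, 21], "stride": [1, 2], "num_channels": 64, "padding": "SAME"},
    {"kernel_size": [11, 21], "stride": [1, 2], "num_channels": 96, "padding": "SAME"},
]


def output_time_steps(num_frames, layers):
    # Assumption: index 0 of kernel_size and stride is the time axis.
    for layer in layers:
        num_frames = conv_output_length(
            num_frames, layer["kernel_size"][0], layer["stride"][0], layer["padding"]
        )
    return num_frames


time_steps = output_time_steps(301, conv_layers)
```

Under these assumptions only the first layer has a time stride greater than 1, so the sequence length is roughly halved.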

_encode(input_dict)[source]

Creates TensorFlow graph for DeepSpeech-2 like encoder.

Parameters:input_dict (dict) –

input dictionary that has to contain the following fields:

input_dict = {
  "source_tensors": [
    src_sequence (shape=[batch_size, sequence length, num features]),
    src_length (shape=[batch_size])
  ]
}
Returns:dictionary with the following tensors:
{
  'outputs': hidden state, shape=[batch_size, sequence length, n_hidden]
  'src_length': tensor, shape=[batch_size]
}
Return type:dict
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
encoders.ds2_encoder.rnn_cell(rnn_cell_dim, layer_type, dropout_keep_prob=1.0)[source]

Helper function that creates RNN cell.

encoders.ds2_encoder.row_conv(name, input_layer, batch, channels, width, activation_fn, regularizer, training, data_format, bn_momentum, bn_epsilon)[source]

Helper function that applies “row” or “in plane” convolution.

tdnn_encoder

class encoders.tdnn_encoder.TDNNEncoder(params, model, name='w2l_encoder', mode='train')[source]

Bases: encoders.encoder.Encoder

General time delay neural network (TDNN) encoder. Fully convolutional model.

__init__(params, model, name='w2l_encoder', mode='train')[source]

TDNN encoder constructor.

See parent class for arguments description.

Config parameters:

  • dropout_keep_prob (float) — keep probability for dropout.

  • convnet_layers (list) — list with the description of convolutional layers. For example:

    "convnet_layers": [
      {
        "type": "conv1d", "repeat" : 5,
        "kernel_size": [7], "stride": [1],
        "num_channels": 250, "padding": "SAME"
      },
      {
        "type": "conv1d", "repeat" : 3,
        "kernel_size": [11], "stride": [1],
        "num_channels": 500, "padding": "SAME"
      },
      {
        "type": "conv1d", "repeat" : 1,
        "kernel_size": [32], "stride": [1],
        "num_channels": 1000, "padding": "SAME"
      },
      {
        "type": "conv1d", "repeat" : 1,
        "kernel_size": [1], "stride": [1],
        "num_channels": 1000, "padding": "SAME"
      },
    ]
    
  • activation_fn — activation function to use.

  • data_format (string) — could be either “channels_first” or “channels_last”. Defaults to “channels_last”.

  • normalization — normalization to use. Accepts [None, ‘batch_norm’]. Use None if you don’t want to use normalization. Defaults to ‘batch_norm’.

  • bn_momentum (float) — momentum for batch norm. Defaults to 0.90.

  • bn_epsilon (float) — epsilon for batch norm. Defaults to 1e-3.

  • drop_block_prob (float) — probability of dropping encoder blocks. Defaults to 0.0 which corresponds to training without dropping blocks.

  • drop_block_index (int) – index of the block to drop on inference. Defaults to -1 which corresponds to keeping all blocks.

  • use_conv_mask (bool) — whether to apply a sequence mask prior to convolution operations. Defaults to False for backwards compatibility. Setting it to True is recommended.
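
To get a feel for the convnet_layers example above, the sketch below counts the convolutions it describes and computes how many input frames influence one output frame. It assumes that "repeat" means the layer is stacked that many times, which the parameter list does not state explicitly, and it only handles stride 1 without dilation. The helper names are hypothetical.

```python
# The convnet_layers example from the config description above.
convnet_layers = [
    {"type": "conv1d", "repeat": 5, "kernel_size": [7], "stride": [1],
     "num_channels": 250, "padding": "SAME"},
    {"type": "conv1d", "repeat": 3, "kernel_size": [11], "stride": [1],
     "num_channels": 500, "padding": "SAME"},
    {"type": "conv1d", "repeat": 1, "kernel_size": [32], "stride": [1],
     "num_channels": 1000, "padding": "SAME"},
    {"type": "conv1d", "repeat": 1, "kernel_size": [1], "stride": [1],
     "num_channels": 1000, "padding": "SAME"},
]


def count_conv_layers(layers):
    # Assumption: "repeat" stacks the same layer description that many times.
    return sum(layer["repeat"] for layer in layers)


def receptive_field(layers):
    """Receptive field in frames for stacked stride-1, undilated convolutions."""
    field = 1
    for layer in layers:
        if layer["stride"] != [1]:
            raise ValueError("this sketch only handles stride 1")
        # Each stride-1 convolution widens the field by kernel_size - 1.
        field += layer["repeat"] * (layer["kernel_size"][0] - 1)
    return field


total_layers = count_conv_layers(convnet_layers)
frames_seen = receptive_field(convnet_layers)
```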

_encode(input_dict)[source]

Creates TensorFlow graph for Wav2Letter like encoder.

Parameters:input_dict (dict) –

input dictionary that has to contain the following fields:

input_dict = {
  "source_tensors": [
    src_sequence (shape=[batch_size, sequence length, num features]),
    src_length (shape=[batch_size])
  ]
}
Returns:dictionary with the following tensors:
{
  'outputs': hidden state, shape=[batch_size, sequence length, n_hidden]
  'src_length': tensor, shape=[batch_size]
}
Return type:dict
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict

rnn_encoders

RNN-based encoders

class encoders.rnn_encoders.BidirectionalRNNEncoderWithEmbedding(params, model, name='bidir_rnn_encoder_with_emb', mode='train')[source]

Bases: encoders.encoder.Encoder

Bi-directional RNN-based encoder with embeddings. Can support various RNN cell types.

__init__(params, model, name='bidir_rnn_encoder_with_emb', mode='train')[source]

Initializes bi-directional encoder with embeddings.

Parameters:params (dict) –

dictionary with encoder parameters. Must define:

  • src_vocab_size - data vocabulary size
  • src_emb_size - size of embedding to use
  • encoder_cell_units - number of units in RNN cell
  • encoder_cell_type - cell type: lstm, gru, etc.
  • encoder_layers - number of layers
  • encoder_dp_input_keep_prob -
  • encoder_dp_output_keep_prob -
  • encoder_use_skip_connections - true/false
  • time_major (optional)
  • use_swap_memory (optional)
  • mode - train or infer

… add any cell-specific parameters here as well

Returns:encoder_params
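
A params dictionary following the list above might look like the sketch below. All values are illustrative placeholders, not defaults or recommendations, and the value types (for example a string for encoder_cell_type) are assumptions based on the descriptions in the list. The check_params helper is hypothetical and only compares key names.

```python
# Keys named in the parameter list above; the split into required and
# optional follows the "(optional)" markers in that list.
REQUIRED_KEYS = [
    "src_vocab_size",
    "src_emb_size",
    "encoder_cell_units",
    "encoder_cell_type",
    "encoder_layers",
    "encoder_dp_input_keep_prob",
    "encoder_dp_output_keep_prob",
    "encoder_use_skip_connections",
    "mode",
]
OPTIONAL_KEYS = ["time_major", "use_swap_memory"]

# Illustrative values only.
encoder_params = {
    "src_vocab_size": 32000,
    "src_emb_size": 512,
    "encoder_cell_units": 512,
    "encoder_cell_type": "lstm",
    "encoder_layers": 2,
    "encoder_dp_input_keep_prob": 0.8,
    "encoder_dp_output_keep_prob": 1.0,
    "encoder_use_skip_connections": False,
    "time_major": False,
    "mode": "train",
}


def check_params(params):
    """Return (missing required keys, keys not named in the list above)."""
    missing = [key for key in REQUIRED_KEYS if key not in params]
    # Unlisted keys are not necessarily errors: cell-specific parameters
    # may be added as well.
    unlisted = [key for key in params if key not in REQUIRED_KEYS + OPTIONAL_KEYS]
    return missing, unlisted


missing, unlisted = check_params(encoder_params)
```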
_encode(input_dict)[source]

Encodes data into representation.

Parameters:input_dict (dict) –

a Python dictionary. Must define:

  • src_inputs - a Tensor of shape [batch_size, time] or [time, batch_size] (depending on time_major param)
  • src_lengths - a Tensor of shape [batch_size]
Returns:a Python dictionary with:
  • encoder_outputs - a Tensor of shape [batch_size, time, representation_dim] or [time, batch_size, representation_dim]
  • encoder_state - a Tensor of shape [batch_size, dim]
  • src_lengths - (copy ref from input) a Tensor of shape [batch_size]
Return type:dict
enc_emb_w
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
src_emb_size
src_vocab_size
class encoders.rnn_encoders.GNMTLikeEncoderWithEmbedding(params, model, name='gnmt_encoder_with_emb', mode='train')[source]

Bases: encoders.encoder.Encoder

Encoder similar to the one used in GNMT model: https://arxiv.org/abs/1609.08144. Must have at least 2 layers.

__init__(params, model, name='gnmt_encoder_with_emb', mode='train')[source]

Encodes data into representation.

Parameters:params (dict) –

a Python dictionary. Must define:

  • src_inputs - a Tensor of shape [batch_size, time] or [time, batch_size] (depending on time_major param)
  • src_lengths - a Tensor of shape [batch_size]
Returns:a Python dictionary with:
  • encoder_outputs - a Tensor of shape [batch_size, time, representation_dim] or [time, batch_size, representation_dim]
  • encoder_state - a Tensor of shape [batch_size, dim]
  • src_lengths - (copy ref from input) a Tensor of shape [batch_size]
Return type:dict
enc_emb_w
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
src_emb_size
src_vocab_size
class encoders.rnn_encoders.GNMTLikeEncoderWithEmbedding_cuDNN(params, model, name='gnmt_encoder_with_emb_cudnn', mode='train')[source]

Bases: encoders.encoder.Encoder

Encoder similar to the one used in GNMT model: https://arxiv.org/abs/1609.08144. Must have at least 2 layers. Uses cuDNN RNN blocks for efficiency.

__init__(params, model, name='gnmt_encoder_with_emb_cudnn', mode='train')[source]

Encodes data into representation.

Parameters:params (dict) –

a Python dictionary. Must define:

  • src_inputs - a Tensor of shape [batch_size, time] or [time, batch_size] (depending on time_major param)
  • src_lengths - a Tensor of shape [batch_size]
Returns:a Python dictionary with:
  • encoder_outputs - a Tensor of shape [batch_size, time, representation_dim] or [time, batch_size, representation_dim]
  • encoder_state - a Tensor of shape [batch_size, dim]
  • src_lengths - (copy ref from input) a Tensor of shape [batch_size]
Return type:dict
enc_emb_w
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
src_emb_size
src_vocab_size
class encoders.rnn_encoders.UnidirectionalRNNEncoderWithEmbedding(params, model, name='unidir_rnn_encoder_with_emb', mode='train')[source]

Bases: encoders.encoder.Encoder

Uni-directional RNN encoder with embeddings. Can support various RNN cell types.

__init__(params, model, name='unidir_rnn_encoder_with_emb', mode='train')[source]

Initializes uni-directional encoder with embeddings.

Parameters:params (dict) –

dictionary with encoder parameters. Must define:

  • src_vocab_size - data vocabulary size
  • src_emb_size - size of embedding to use
  • encoder_cell_units - number of units in RNN cell
  • encoder_cell_type - cell type: lstm, gru, etc.
  • encoder_layers - number of layers
  • encoder_dp_input_keep_prob -
  • encoder_dp_output_keep_prob -
  • encoder_use_skip_connections - true/false
  • time_major (optional)
  • use_swap_memory (optional)
  • mode - train or infer

… add any cell-specific parameters here as well

_encode(input_dict)[source]

Encodes data into representation.

Parameters:input_dict (dict) –

a Python dictionary. Must define:

  • src_inputs - a Tensor of shape [batch_size, time] or [time, batch_size] (depending on time_major param)
  • src_lengths - a Tensor of shape [batch_size]
Returns:a Python dictionary with:
  • encoder_outputs - a Tensor of shape [batch_size, time, representation_dim] or [time, batch_size, representation_dim]
  • encoder_state - a Tensor of shape [batch_size, dim]
  • src_lengths - (copy ref from input) a Tensor of shape [batch_size]
Return type:dict

enc_emb_w
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
src_emb_size
src_vocab_size

transformer_encoder

class encoders.transformer_encoder.TransformerEncoder(params, model, name='transformer_encoder', mode='train')[source]

Bases: open_seq2seq.encoders.encoder.Encoder

Transformer model encoder.

static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict

convs2s_encoder

Conv-based encoder

class encoders.convs2s_encoder.ConvS2SEncoder(params, model, name='convs2s_encoder_with_emb', mode='train')[source]

Bases: encoders.encoder.Encoder

Fully convolutional encoder of ConvS2S.

static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
src_emb_size
src_vocab_size

resnet_encoder

class encoders.resnet_encoder.ResNetEncoder(params, model, name='resnet_encoder', mode='train')[source]

Bases: encoders.encoder.Encoder

static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict

resnet_blocks

Contains definitions for Residual Networks.

Residual networks (‘v1’ ResNets) were originally proposed in:

[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. arXiv:1512.03385

The full preactivation ‘v2’ ResNet variant was introduced by:

[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Identity Mappings in Deep Residual Networks. arXiv:1603.05027

The key difference of the full preactivation ‘v2’ variant compared to the ‘v1’ variant in [1] is the use of batch normalization before every weight layer rather than after.

encoders.resnet_blocks.batch_norm(inputs, training, data_format, regularizer, momentum, epsilon)[source]

Performs a batch normalization using a standard set of parameters.

encoders.resnet_blocks.block_layer(inputs, filters, bottleneck, block_fn, blocks, strides, training, name, data_format, regularizer, bn_regularizer, bn_momentum, bn_epsilon)[source]

Creates one layer of blocks for the ResNet model.

Parameters:
  • inputs – A tensor of size [batch, channels, height_in, width_in] or [batch, height_in, width_in, channels] depending on data_format.
  • filters – The number of filters for the first convolution of the layer.
  • bottleneck – Whether the blocks created are bottleneck blocks.
  • block_fn – The block to use within the model, either building_block or bottleneck_block.
  • blocks – The number of blocks contained in the layer.
  • strides – The stride to use for the first convolution of the layer. If greater than 1, this layer will ultimately downsample the input.
  • training – Either True or False, whether we are currently training the model. Needed for batch norm.
  • name – A string name for the tensor output of the block layer.
  • data_format – The input format (‘channels_last’ or ‘channels_first’).
Returns:

The output tensor of the block layer.
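
The parameter descriptions above say that strides applies to the first convolution of the layer. A minimal sketch of how such a layer could be laid out follows. It assumes the conventional ResNet arrangement in which only the first block downsamples and uses a projection shortcut, and bottleneck blocks output four times the given number of filters. That arrangement is an assumption of this sketch, not something confirmed by the documentation above, and block_layer_plan is a hypothetical helper.

```python
def block_layer_plan(filters, bottleneck, blocks, strides):
    """Return (strides, uses_projection_shortcut, output_filters) for each block."""
    # Assumption: bottleneck blocks expand the channel count by a factor of 4.
    filters_out = filters * 4 if bottleneck else filters
    # Only the first block changes resolution, so only it needs a projection
    # shortcut to make the shortcut shape match the block output.
    plan = [(strides, True, filters_out)]
    plan.extend((1, False, filters_out) for _ in range(1, blocks))
    return plan


plan = block_layer_plan(filters=64, bottleneck=True, blocks=3, strides=2)
```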

encoders.resnet_blocks.bottleneck_block_v1(inputs, filters, training, projection_shortcut, strides, data_format, regularizer, bn_regularizer, bn_momentum, bn_epsilon)[source]

A single block for ResNet v1, with a bottleneck.

Similar to building_block_v1(), except that it uses “bottleneck” blocks.

Convolution, then batch normalization, then ReLU, as described in:
Deep Residual Learning for Image Recognition https://arxiv.org/pdf/1512.03385.pdf by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
Parameters:
  • inputs – A tensor of size [batch, channels, height_in, width_in] or [batch, height_in, width_in, channels] depending on data_format.
  • filters – The number of filters for the convolutions.
  • training – A Boolean for whether the model is in training or inference mode. Needed for batch normalization.
  • projection_shortcut – The function to use for projection shortcuts (typically a 1x1 convolution when downsampling the input).
  • strides – The block’s stride. If greater than 1, this block will ultimately downsample the input.
  • data_format – The input format (‘channels_last’ or ‘channels_first’).
Returns:

The output tensor of the block; shape should match inputs.

encoders.resnet_blocks.bottleneck_block_v2(inputs, filters, training, projection_shortcut, strides, data_format, regularizer, bn_regularizer, bn_momentum, bn_epsilon)[source]

A single block for ResNet v2, with a bottleneck.

Similar to building_block_v2(), except that it uses “bottleneck” blocks.

Convolution, then batch normalization, then ReLU, as described in:
Deep Residual Learning for Image Recognition https://arxiv.org/pdf/1512.03385.pdf by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
Adapted to the ordering conventions of:
Batch normalization, then ReLU, then convolution, as described in:
Identity Mappings in Deep Residual Networks https://arxiv.org/pdf/1603.05027.pdf by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
Parameters:
  • inputs – A tensor of size [batch, channels, height_in, width_in] or [batch, height_in, width_in, channels] depending on data_format.
  • filters – The number of filters for the convolutions.
  • training – A Boolean for whether the model is in training or inference mode. Needed for batch normalization.
  • projection_shortcut – The function to use for projection shortcuts (typically a 1x1 convolution when downsampling the input).
  • strides – The block’s stride. If greater than 1, this block will ultimately downsample the input.
  • data_format – The input format (‘channels_last’ or ‘channels_first’).
Returns:

The output tensor of the block; shape should match inputs.

encoders.resnet_blocks.building_block_v1(inputs, filters, training, projection_shortcut, strides, data_format, regularizer, bn_regularizer, bn_momentum, bn_epsilon)[source]

A single block for ResNet v1, without a bottleneck.

Convolution then batch normalization then ReLU as described by:
Deep Residual Learning for Image Recognition https://arxiv.org/pdf/1512.03385.pdf by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
Parameters:
  • inputs – A tensor of size [batch, channels, height_in, width_in] or [batch, height_in, width_in, channels] depending on data_format.
  • filters – The number of filters for the convolutions.
  • training – A Boolean for whether the model is in training or inference mode. Needed for batch normalization.
  • projection_shortcut – The function to use for projection shortcuts (typically a 1x1 convolution when downsampling the input).
  • strides – The block’s stride. If greater than 1, this block will ultimately downsample the input.
  • data_format – The input format (‘channels_last’ or ‘channels_first’).
Returns:

The output tensor of the block; shape should match inputs.

encoders.resnet_blocks.building_block_v2(inputs, filters, training, projection_shortcut, strides, data_format, regularizer, bn_regularizer, bn_momentum, bn_epsilon)[source]

A single block for ResNet v2, without a bottleneck.

Batch normalization then ReLU then convolution as described by:
Identity Mappings in Deep Residual Networks https://arxiv.org/pdf/1603.05027.pdf by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
Parameters:
  • inputs – A tensor of size [batch, channels, height_in, width_in] or [batch, height_in, width_in, channels] depending on data_format.
  • filters – The number of filters for the convolutions.
  • training – A Boolean for whether the model is in training or inference mode. Needed for batch normalization.
  • projection_shortcut – The function to use for projection shortcuts (typically a 1x1 convolution when downsampling the input).
  • strides – The block’s stride. If greater than 1, this block will ultimately downsample the input.
  • data_format – The input format (‘channels_last’ or ‘channels_first’).
Returns:

The output tensor of the block; shape should match inputs.

encoders.resnet_blocks.conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format, regularizer)[source]

Strided 2-D convolution with explicit padding.

encoders.resnet_blocks.fixed_padding(inputs, kernel_size, data_format)[source]

Pads the input along the spatial dimensions independently of input size.

Parameters:
  • inputs – A tensor of size [batch, channels, height_in, width_in] or [batch, height_in, width_in, channels] depending on data_format.
  • kernel_size – The kernel to be used in the conv2d or max_pool2d operation. Should be a positive integer.
  • data_format – The input format (‘channels_last’ or ‘channels_first’).
Returns:

A tensor with the same format as the input with the data either intact (if kernel_size == 1) or padded (if kernel_size > 1).
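
Padding that is independent of input size can be computed from the kernel size alone. The sketch below assumes the common scheme in which a total of kernel_size - 1 is split between the two sides of each spatial axis, with any odd remainder going to the trailing side. Whether this module uses exactly this split is an assumption. It works on a plain 2-D list rather than a tensor, and both function names are hypothetical.

```python
def padding_amounts(kernel_size):
    """Split the total padding (kernel_size - 1) into leading and trailing parts."""
    pad_total = kernel_size - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return pad_beg, pad_end


def fixed_padding_2d(image, kernel_size, fill=0):
    """Pad a 2-D list (height x width) on both spatial axes with a fill value."""
    pad_beg, pad_end = padding_amounts(kernel_size)
    width = len(image[0]) + pad_beg + pad_end
    # Leading rows, then each original row padded left and right, then trailing rows.
    padded = [[fill] * width for _ in range(pad_beg)]
    for row in image:
        padded.append([fill] * pad_beg + list(row) + [fill] * pad_end)
    padded.extend([fill] * width for _ in range(pad_end))
    return padded


padded = fixed_padding_2d([[1, 2], [3, 4]], kernel_size=3)
```

With kernel_size == 1 both padding amounts are zero, so the data is returned intact, matching the description of the return value above.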

cnn_encoder

This module contains classes and functions to build “general” convolutional neural networks from the description of arbitrary “layers”.

class encoders.cnn_encoder.CNNEncoder(params, model, name='cnn_encoder', mode='train')[source]

Bases: encoders.encoder.Encoder

General CNN encoder that can be used to construct various different models.

__init__(params, model, name='cnn_encoder', mode='train')[source]

CNN Encoder constructor.

See parent class for arguments description.

Config parameters:

  • cnn_layers (list) — list with the description of “convolutional” layers. For example:

    "cnn_layers": [
        (tf.layers.conv2d, {
            'filters': 64, 'kernel_size': (11, 11),
            'strides': (4, 4), 'padding': 'VALID',
            'activation': tf.nn.relu,
        }),
        (tf.layers.max_pooling2d, {
            'pool_size': (3, 3), 'strides': (2, 2),
        }),
        (tf.layers.conv2d, {
            'filters': 192, 'kernel_size': (5, 5),
            'strides': (1, 1), 'padding': 'SAME',
        }),
        (tf.layers.batch_normalization, {'momentum': 0.9, 'epsilon': 0.0001}),
        (tf.nn.relu, {}),
    ]
    

    Note that you don’t need to provide “regularizer”, “training”, “data_format” and “axis” parameters since they will be automatically added. “axis” will be derived from “data_format” and will be 1 if data_format == "channels_first" else 3.

  • fc_layers (list) — list with the description of “fully-connected” layers. The only difference from convolutional layers is that the input will be automatically reshaped to 2D (batch size x num features). For example:

    'fc_layers': [
        (tf.layers.dense, {'units': 4096, 'activation': tf.nn.relu}),
        (tf.layers.dropout, {'rate': 0.5}),
        (tf.layers.dense, {'units': 4096, 'activation': tf.nn.relu}),
        (tf.layers.dropout, {'rate': 0.5}),
    ],
    

    Note that you don’t need to provide “regularizer”, “training”, “data_format” and “axis” parameters since they will be automatically added. “axis” will be derived from “data_format” and will be 1 if data_format == "channels_first" else 3.

  • data_format (string) — could be either “channels_first” or “channels_last”. Defaults to “channels_first”.
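
The reshape to 2D mentioned for fc_layers keeps the batch dimension and multiplies the remaining dimensions together. A minimal sketch of that shape arithmetic follows; flattened_shape is a hypothetical helper and the example shape is illustrative.

```python
import math


def flattened_shape(shape):
    """Shape after collapsing every non-batch dimension into one feature axis."""
    batch_size, *feature_dims = shape
    return [batch_size, math.prod(feature_dims)]


# For example, a channels_first feature map of 256 channels and 6 x 6 pixels.
fc_input_shape = flattened_shape([32, 256, 6, 6])
```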

static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
encoders.cnn_encoder.build_layer(inputs, layer, layer_params, data_format, regularizer, training, verbose=True)[source]

This function builds a layer from the layer function and its parameters.

It automatically adds a regularizer parameter to layer_params if the layer supports regularization. To check this, it looks for the names “regularizer”, “kernel_regularizer” and “gamma_regularizer”, in this order, in the layer call signature. If one of these parameters is supported, the regularizer object is passed as its value. Using the same signature check, the function tries to add the “data_format” and “training” parameters. Finally, it tries to set the “axis” parameter with axis = 1 if data_format == 'channels_first' else 3. This is required to build batch normalization layers automatically.

Parameters:
  • inputs – input Tensor that will be passed to the layer. Note that layer has to accept input as the first parameter.
  • layer – layer function or class with __call__ method defined.
  • layer_params (dict) – parameters passed to the layer.
  • data_format (string) – data format (“channels_first” or “channels_last”); passed to the layer as an additional argument if the layer accepts it.
  • regularizer – regularizer instance; passed to the layer as an additional argument if the layer accepts it.
  • training (bool) – whether the layer is built in training mode; passed to the layer as an additional argument if the layer accepts it.
  • verbose (bool) – whether to print information about built layers.
Returns:

Tensor with layer output.
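
The signature check described above can be sketched with the standard inspect module. This is an illustration of the technique, not the function's actual source. build_layer_sketch and the two fake layers are hypothetical, and the choice to leave explicitly supplied values untouched is an assumption of this sketch.

```python
import inspect


def build_layer_sketch(inputs, layer, layer_params, data_format, regularizer, training):
    """Call layer(inputs, **params), adding extra arguments the signature accepts."""
    params = dict(layer_params)  # do not mutate the caller's dictionary
    accepted = inspect.signature(layer).parameters

    # Pass the regularizer under the first supported name, checked in this order.
    for reg_name in ("regularizer", "kernel_regularizer", "gamma_regularizer"):
        if reg_name in accepted:
            params.setdefault(reg_name, regularizer)
            break

    extras = {
        "data_format": data_format,
        "training": training,
        "axis": 1 if data_format == "channels_first" else 3,
    }
    for key, value in extras.items():
        if key in accepted:
            # Assumption of this sketch: values given in layer_params win.
            params.setdefault(key, value)

    return layer(inputs, **params)


# Hypothetical stand-ins for real layer functions; they just report their arguments.
def fake_conv(inputs, filters, kernel_regularizer=None, data_format="channels_last"):
    return ("conv", filters, kernel_regularizer, data_format)


def fake_batch_norm(inputs, axis=-1, gamma_regularizer=None, training=False):
    return ("bn", axis, gamma_regularizer, training)


conv_out = build_layer_sketch("x", fake_conv, {"filters": 64}, "channels_first", "l2", True)
bn_out = build_layer_sketch("x", fake_batch_norm, {}, "channels_first", "l2", True)
```

Inspecting the signature lets one list of (layer, params) pairs mix convolutions, pooling, batch normalization and plain activation functions without each entry spelling out the shared arguments.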

tacotron2_encoder

class encoders.tacotron2_encoder.Tacotron2Encoder(params, model, name='tacotron2_encoder', mode='train')[source]

Bases: encoders.encoder.Encoder

Tacotron-2 like encoder.

Consists of an embedding layer followed by a convolutional layer followed by a recurrent layer.

__init__(params, model, name='tacotron2_encoder', mode='train')[source]

Tacotron-2 like encoder constructor.

See parent class for arguments description.

Config parameters:

  • cnn_dropout_prob (float) — dropout probability for cnn layers.

  • rnn_dropout_prob (float) — dropout probability for rnn layers.

  • src_emb_size (int) — dimensionality of character embedding.

  • conv_layers (list) — list with the description of convolutional layers. For example:

    "conv_layers": [
      {
        "kernel_size": [5], "stride": [1],
        "num_channels": 512, "padding": "SAME"
      },
      {
        "kernel_size": [5], "stride": [1],
        "num_channels": 512, "padding": "SAME"
      },
      {
        "kernel_size": [5], "stride": [1],
        "num_channels": 512, "padding": "SAME"
      }
    ]
    
  • activation_fn (callable) — activation function to use for conv layers.

  • num_rnn_layers — number of RNN layers to use.

  • rnn_cell_dim (int) — dimension of RNN cells.

  • rnn_type (callable) — Any valid RNN Cell class. Suggested class is lstm.

  • rnn_unidirectional (bool) — whether to use uni-directional or bi-directional RNNs.

  • zoneout_prob (float) — zoneout probability. Defaults to 0.

  • use_cudnn_rnn (bool) — needs to be enabled if rnn_type is a cuDNN class.

  • data_format (string) — could be either “channels_first” or “channels_last”. Defaults to “channels_last”.

  • bn_momentum (float) — momentum for batch norm. Defaults to 0.1.

  • bn_epsilon (float) — epsilon for batch norm. Defaults to 1e-5.

  • style_embedding_enable (bool) — Whether to enable GST. Defaults to False.

  • style_embedding_params (dict) — Parameters for GST layer. See _embed_style documentation.
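As an illustration, the parameters above could be assembled into a params dictionary along these lines. This is a minimal sketch: the numeric values are assumptions chosen for illustration rather than recommended settings, make_conv_layers is a hypothetical helper, and the callables (activation_fn, rnn_type) are left as None placeholders so the snippet runs without TensorFlow. A real config would pass TensorFlow callables there.

```python
import copy


def make_conv_layers(num_layers, kernel_size=5, num_channels=512):
    """Hypothetical helper: build `num_layers` identical conv layer specs."""
    layer = {
        "kernel_size": [kernel_size], "stride": [1],
        "num_channels": num_channels, "padding": "SAME",
    }
    # Deep-copy so each layer spec can be edited independently.
    return [copy.deepcopy(layer) for _ in range(num_layers)]


# Illustrative values only; not taken from a shipped config.
tacotron2_encoder_params = {
    "cnn_dropout_prob": 0.5,
    "rnn_dropout_prob": 0.0,
    "src_emb_size": 512,
    "conv_layers": make_conv_layers(3),
    "activation_fn": None,   # placeholder: a TensorFlow activation function
    "num_rnn_layers": 1,
    "rnn_cell_dim": 256,
    "rnn_type": None,        # placeholder: an RNN cell class
    "rnn_unidirectional": False,
    "zoneout_prob": 0.0,
    "use_cudnn_rnn": False,
    "data_format": "channels_last",
    "bn_momentum": 0.1,
    "bn_epsilon": 1e-5,
    "style_embedding_enable": False,
}
```

Building the conv_layers list with a helper avoids repeating the same layer description three times, as in the example above.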

_embed_style(style_spec, style_len)[source]

Implements the reference encoder as described in “Towards end-to-end prosody transfer for expressive speech synthesis with Tacotron” and “Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis”.

Config parameters:

  • conv_layers (list) — See the conv_layers parameter for the Tacotron-2 model.
  • num_rnn_layers (int) — Number of RNN layers in the reference encoder
  • rnn_cell_dim (int) — Size of each RNN layer
  • rnn_unidirectional (bool) — Whether to use uni-directional or bi-directional RNNs
  • rnn_type — Must be a valid TensorFlow RNN cell class
  • emb_size (int) — Size of the GST embedding
  • attention_layer_size (int) — Size of linear layers in attention
  • num_tokens (int) — Number of tokens for GST
  • num_heads (int) — Number of attention heads
_encode(input_dict)[source]

Creates TensorFlow graph for Tacotron-2 like encoder.

Parameters:input_dict (dict) –

dictionary with inputs. Must define source_tensors, an array containing:

  • source_sequence: tensor of shape [batch_size, sequence length]
  • src_length: tensor of shape [batch_size]

Returns:A python dictionary containing:
  • outputs - tensor containing the encoded text to be passed to the attention layer
  • src_length - the length of the encoded text
Return type:dict
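
The expected input layout can be illustrated with plain Python lists standing in for tensors. This is only a sketch: pad_batch and the pad ID of 0 are assumptions for illustration, and the real encoder receives TensorFlow tensors rather than lists.

```python
def pad_batch(sequences, pad_id=0):
    """Hypothetical helper: pad character-ID sequences to a common length."""
    src_length = [len(seq) for seq in sequences]
    max_len = max(src_length)
    source_sequence = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    return source_sequence, src_length


# Two utterances encoded as character IDs (made-up values).
source_sequence, src_length = pad_batch([[5, 8, 2], [7, 1]])

# source_sequence has shape [batch_size, sequence length]: [[5, 8, 2], [7, 1, 0]]
# src_length has shape [batch_size]: [3, 2]
input_dict = {"source_tensors": [source_sequence, src_length]}
```
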
static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
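
One way such dictionaries could be used is to validate a user config before the encoder is constructed. The sketch below assumes the dictionaries map parameter names to expected Python types, which the text above does not state. validate_params and the two example dictionaries are hypothetical and not part of the library.

```python
def validate_params(params, required, optional):
    """Hypothetical check of `params` against name -> type dictionaries."""
    missing = [name for name in required if name not in params]
    if missing:
        raise ValueError("missing required parameters: %s" % sorted(missing))
    allowed = dict(optional)
    allowed.update(required)
    for name, value in params.items():
        if name not in allowed:
            raise ValueError("unknown parameter: %s" % name)
        expected = allowed[name]
        # `None` as the expected type means "accept anything" in this sketch.
        if expected is not None and not isinstance(value, expected):
            raise TypeError("%s should be %s" % (name, expected.__name__))
    return True


# Made-up stand-ins for what the two static methods might return.
required = {"src_emb_size": int, "rnn_cell_dim": int}
optional = {"zoneout_prob": float, "activation_fn": None}

ok = validate_params({"src_emb_size": 512, "rnn_cell_dim": 256,
                      "zoneout_prob": 0.1}, required, optional)
```

Failing early on a missing or mistyped parameter keeps configuration errors out of graph construction.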

wavenet_encoder

class encoders.wavenet_encoder.WavenetEncoder(params, model, name='wavenet_encoder', mode='train')[source]

Bases: encoders.encoder.Encoder

WaveNet like encoder.

Consists of several blocks of dilated causal convolutions.

__init__(params, model, name='wavenet_encoder', mode='train')[source]

WaveNet like encoder constructor.

Config parameters:

  • layer_type (str) — type of convolutional layer, currently only supports “conv1d”
  • kernel_size (int) — size of kernel
  • strides (int) — size of stride
  • padding (str) — padding, can be “SAME” or “VALID”
  • blocks (int) — number of dilation cycles
  • layers_per_block (int) — number of dilated convolutional layers in each block
  • filters (int) — number of output channels
  • quantization_channels (int) — depth of mu-law quantized input
  • data_format (string) — could be either “channels_first” or “channels_last”. Defaults to “channels_last”.
  • bn_momentum (float) — momentum for batch norm. Defaults to 0.1.
  • bn_epsilon (float) — epsilon for batch norm. Defaults to 1e-5.
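
To get a feel for how blocks, layers_per_block and kernel_size interact, the receptive field of the stack can be estimated. The sketch below assumes that dilation doubles at every layer within a block (1, 2, 4, ...) and resets at each new block, as in the original WaveNet paper, and that strides is 1. The parameter list above does not confirm this schedule, and receptive_field is a hypothetical helper.

```python
def receptive_field(kernel_size, blocks, layers_per_block):
    """Number of input samples seen by one output sample (stride 1)."""
    # Assumed dilation schedule within one block: 1, 2, 4, ..., 2**(L-1).
    dilations = [2 ** i for i in range(layers_per_block)]
    # Each layer widens the receptive field by (kernel_size - 1) * dilation.
    per_block = (kernel_size - 1) * sum(dilations)
    return 1 + blocks * per_block


# Example: kernel_size=2, 3 blocks of 10 layers each.
samples = receptive_field(kernel_size=2, blocks=3, layers_per_block=10)
print(samples)  # 3070
```
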
_encode(input_dict)[source]

Creates TensorFlow graph for WaveNet like encoder. …

static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict
encoders.wavenet_encoder.causal_conv_bn_actv(layer_type, name, inputs, filters, kernel_size, activation_fn, strides, padding, regularizer, training, data_format, bn_momentum, bn_epsilon, dilation=1)[source]

Defines a single dilated causal convolutional layer with batch norm.
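
The idea of a dilated causal convolution can be sketched in plain Python. This illustrates the general technique only, not this function's TensorFlow implementation: causal_conv1d is a hypothetical single-channel helper without batch norm or activation. The input is padded on the left by (kernel_size - 1) * dilation zeros so that output t depends only on inputs at positions t and earlier.

```python
def causal_conv1d(inputs, weights, dilation=1):
    """Single-channel dilated causal convolution over a list of floats."""
    kernel_size = len(weights)
    # Left-pad so the output has the same length and never looks ahead.
    pad = (kernel_size - 1) * dilation
    padded = [0.0] * pad + list(inputs)
    outputs = []
    for t in range(len(inputs)):
        # Tap j reads the sample j * dilation steps into the padded window.
        outputs.append(sum(weights[j] * padded[t + j * dilation]
                           for j in range(kernel_size)))
    return outputs


x = [1.0, 2.0, 3.0, 4.0]
y1 = causal_conv1d(x, [1.0, 1.0], dilation=1)  # [1.0, 3.0, 5.0, 7.0]
y2 = causal_conv1d(x, [1.0, 1.0], dilation=2)  # [1.0, 2.0, 4.0, 6.0]
```

With dilation 2 each output sums the current sample and the one two steps back, which is how stacked layers cover long contexts with few parameters.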

encoders.wavenet_encoder.conv_1x1(layer_type, name, inputs, filters, strides, regularizer, training, data_format)[source]

Defines a single 1x1 convolution for convenience.

encoders.wavenet_encoder.wavenet_conv_block(layer_type, name, inputs, condition_filter, condition_gate, filters, kernel_size, strides, padding, regularizer, training, data_format, bn_momentum, bn_epsilon, layers_per_block)[source]

Defines a single WaveNet block using the architecture specified in the original paper, including skip and residual connections.
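
The gated activation with residual and skip outputs from the original WaveNet paper can be sketched per sample as follows. This is a scalar illustration under assumptions, not the function's implementation: wavenet_gate is hypothetical, the 1x1 convolutions are omitted, and the conditioning terms are assumed to be added to the filter and gate pre-activations.

```python
import math


def wavenet_gate(x, filter_pre, gate_pre, condition_filter=0.0, condition_gate=0.0):
    """Return (residual_output, skip_output) for one scalar sample."""
    # Gated activation unit: tanh(filter) * sigmoid(gate).
    f = math.tanh(filter_pre + condition_filter)
    g = 1.0 / (1.0 + math.exp(-(gate_pre + condition_gate)))
    z = f * g
    # Residual connection adds the block input back; the skip path carries z.
    return x + z, z


residual, skip = wavenet_gate(x=0.5, filter_pre=0.0, gate_pre=0.0)
# tanh(0) = 0, so skip is 0.0 and residual equals the input, 0.5.
```
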

centaur_encoder

class encoders.centaur_encoder.CentaurEncoder(params, model, name='centaur_encoder', mode='train')[source]

Bases: open_seq2seq.encoders.encoder.Encoder

Centaur encoder that consists of convolutional layers.

__init__(params, model, name='centaur_encoder', mode='train')[source]

Centaur encoder constructor.

See parent class for arguments description.

Config parameters:

  • src_vocab_size (int) — number of symbols in alphabet.

  • embedding_size (int) — dimensionality of character embedding.

  • output_size (int) — dimensionality of output embedding.

  • conv_layers (list) — list with the description of convolutional layers. For example:

    "conv_layers": [
      {
        "kernel_size": [5], "stride": [1],
        "num_channels": 512, "padding": "SAME"
      },
      {
        "kernel_size": [5], "stride": [1],
        "num_channels": 512, "padding": "SAME"
      },
      {
        "kernel_size": [5], "stride": [1],
        "num_channels": 512, "padding": "SAME"
      }
    ]
    
  • bn_momentum (float) — momentum for batch norm. Defaults to 0.95.

  • bn_epsilon (float) — epsilon for batch norm. Defaults to 1e-8.

  • cnn_dropout_prob (float) — dropout probability for CNN layers. Defaults to 0.5.
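
The defaults listed above could be combined with user-supplied values as in the sketch below. CENTAUR_DEFAULTS and with_defaults are hypothetical names, and the non-default values (vocabulary size, embedding and output sizes) are made up for illustration.

```python
# Defaults as documented above for the optional parameters.
CENTAUR_DEFAULTS = {
    "bn_momentum": 0.95,
    "bn_epsilon": 1e-8,
    "cnn_dropout_prob": 0.5,
}


def with_defaults(params, defaults=CENTAUR_DEFAULTS):
    """Hypothetical helper: user values take precedence over defaults."""
    merged = dict(defaults)
    merged.update(params)
    return merged


centaur_encoder_params = with_defaults({
    "src_vocab_size": 94,      # made-up alphabet size
    "embedding_size": 256,
    "output_size": 256,
    "conv_layers": [
        {"kernel_size": [5], "stride": [1],
         "num_channels": 512, "padding": "SAME"},
    ],
    "cnn_dropout_prob": 0.1,   # overrides the 0.5 default
})
```
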

static get_optional_params()[source]

Static method with description of optional parameters.

Returns:Dictionary containing all the parameters that can be included into the params parameter of the class __init__() method.
Return type:dict
static get_required_params()[source]

Static method with description of required parameters.

Returns:Dictionary containing all the parameters that have to be included into the params parameter of the class __init__() method.
Return type:dict