paddle
- class transformer_engine.paddle.Linear(in_features, out_features, **kwargs)
Applies a linear transformation to the incoming data \(y = xA^T + b\)
- Parameters
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.
bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.
backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only, non-FP8 path is executed with limited optimization.
- Parallelism parameters
tp_group (ProcessGroup, default = None) – tensor parallel process group.
parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.
- forward(*args, **kwargs)
Apply the linear transformation to the input.
- Parameters
inp (paddle.Tensor) – Input tensor.
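The formula \(y = xA^T + b\) can be checked on small inputs with a framework-free reference. This is a sketch for verifying shapes and values only, not how Transformer Engine computes the layer; the function name is hypothetical, and the weight layout (out_features, in_features) is assumed from the transpose in the formula.

```python
from typing import List, Optional

Matrix = List[List[float]]


def linear_reference(x: Matrix, weight: Matrix, bias: Optional[List[float]] = None) -> Matrix:
    """Compute y = x A^T + b on nested lists.

    x:      [batch, in_features]
    weight: [out_features, in_features]  (the matrix A; layout assumed from the A^T in the formula)
    bias:   [out_features], or None for no bias
    """
    out = []
    for row in x:
        y_row = []
        for j, a_row in enumerate(weight):
            # Row j of A is column j of A^T, so this dot product is (x A^T)[j].
            acc = sum(xi * ai for xi, ai in zip(row, a_row))
            if bias is not None:
                acc += bias[j]
            y_row.append(acc)
        out.append(y_row)
    return out


x = [[1.0, 2.0, 3.0]]                    # batch = 1, in_features = 3
A = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]   # out_features = 2
b = [0.5, -1.0]
print(linear_reference(x, A, b))         # [[1.5, 4.0]]
```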
- class transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs)
Applies Layer Normalization over a mini-batch of inputs as described in the paper “Layer Normalization”:
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta\]
where \(\gamma\) and \(\beta\) are learnable affine transform parameters of size hidden_size.
- Parameters
hidden_size (int) – size of each input sample.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.
bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.
zero_centered_gamma (bool, default = False) – if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for the layer normalization operation.
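The two LayerNorm formulas above can be compared with a single-sample reference. This framework-free sketch (the function name is hypothetical, and it is not the fused kernel) shows that zero_centered_gamma only reparameterizes the scale: a stored gamma of 0 under the (1 + gamma) formula gives the same output as a gamma of 1 under the standard formula.

```python
import math
from typing import List


def layer_norm(x: List[float], gamma: List[float], beta: List[float],
               eps: float = 1e-5, zero_centered_gamma: bool = False) -> List[float]:
    """Normalize one sample of length hidden_size, then apply the affine transform."""
    n = len(x)
    mean = sum(x) / n
    # Biased variance (divide by n), matching Var[x] in the formula.
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    # With zero_centered_gamma, the stored gamma is an offset from 1, so the scale is (1 + gamma).
    scale = [1.0 + g for g in gamma] if zero_centered_gamma else list(gamma)
    return [(v - mean) * inv_std * s + b for v, s, b in zip(x, scale, beta)]


x = [1.0, 2.0, 3.0, 4.0]
beta = [0.0] * 4
standard = layer_norm(x, gamma=[1.0] * 4, beta=beta)
centered = layer_norm(x, gamma=[0.0] * 4, beta=beta, zero_centered_gamma=True)
print(standard == centered)  # True
```

Initializing the stored parameter at 0 rather than 1 is the design choice the option describes; the effective scale at initialization is identical.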
- class transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs)
Applies layer normalization followed by linear transformation to the incoming data.
- Parameters
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.
bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.
return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.
zero_centered_gamma (bool, default = False) – if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only, non-FP8 path is executed with limited optimization.
- Parallelism parameters
tp_group (ProcessGroup, default = None) – tensor parallel process group.
parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.
- forward(*args, **kwargs)
Apply layer normalization to the input followed by a linear transformation.
- Parameters
inp (paddle.Tensor) – Input tensor.
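The return_layernorm_output option can be illustrated with a single-sample reference: the layer returns the linear output and, when requested, the normalized input as well, so the caller can take the residual connection post layernorm. This is a framework-free sketch with hypothetical function names, not the fused Transformer Engine implementation.

```python
import math
from typing import List, Tuple, Union

Vector = List[float]
Matrix = List[List[float]]


def _layer_norm(x: Vector, gamma: Vector, beta: Vector, eps: float) -> Vector:
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) * g + b for v, g, b in zip(x, gamma, beta)]


def layer_norm_linear(x: Vector, gamma: Vector, beta: Vector, weight: Matrix, bias: Vector,
                      eps: float = 1e-5,
                      return_layernorm_output: bool = False) -> Union[Vector, Tuple[Vector, Vector]]:
    """LayerNorm on one sample, then y = ln_out A^T + b."""
    ln_out = _layer_norm(x, gamma, beta, eps)
    y = [sum(a * w for a, w in zip(ln_out, row)) + b for row, b in zip(weight, bias)]
    if return_layernorm_output:
        # Also hand back the normalized input, e.g. to build a post-layernorm residual.
        return y, ln_out
    return y


x = [0.5, -1.0, 2.0]
gamma, beta = [1.0] * 3, [0.0] * 3
identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
y, ln_out = layer_norm_linear(x, gamma, beta, identity, [0.0] * 3, return_layernorm_output=True)
# Residual connection taken post layernorm (requires out_features == in_features).
hidden = [yi + ri for yi, ri in zip(y, ln_out)]
```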
- class transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs)
Applies layer normalization to the input, followed by the MLP module, which consists of two successive linear transformations separated by an activation function (GELU by default; see activation).
- Parameters
hidden_size (int) – size of each input sample.
ffn_hidden_size (int) – intermediate size to which input samples are projected.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.
bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.
activation (str, default = 'gelu') – activation function used. Options: ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, ‘swiglu’.
return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.
zero_centered_gamma (bool, default = False) – if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only, non-FP8 path is executed with limited optimization.
- Parallelism parameters
set_parallel_mode (bool, default = False) – if set to True, FC1 is used as Column Parallel and FC2 is used as Row Parallel as described here.
tp_group (paddle.distributed.collective.Group, default = None) – tensor parallel process group.
- forward(*args, **kwargs)
Apply layer normalization to the input followed by a feedforward network (MLP Block).
- Parameters
inp (paddle.Tensor) – Input tensor.
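Why set_parallel_mode pairs a Column Parallel FC1 with a Row Parallel FC2 can be seen by simulating the tensor-parallel ranks in one process. In this framework-free sketch (hypothetical names; LayerNorm omitted, so the input is assumed already normalized), each rank owns a slice of FC1's output features and the matching slice of FC2's input features. Because GELU is elementwise it can be applied locally, and a single sum across ranks, standing in for the all-reduce, recovers the unsharded result.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def gelu(v: float) -> float:
    # Exact (erf-based) GELU.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def mlp(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """FC1 -> GELU -> FC2 on a single (already normalized) sample."""
    h = [gelu(v + b) for v, b in zip(matvec(w1, x), b1)]
    return [v + b for v, b in zip(matvec(w2, h), b2)]


def mlp_tensor_parallel(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector,
                        tp_size: int) -> Vector:
    """Same computation, simulating tp_size ranks in one process."""
    ffn_hidden_size = len(w1)
    if ffn_hidden_size % tp_size != 0:
        raise ValueError("ffn_hidden_size must be divisible by tp_size")
    chunk = ffn_hidden_size // tp_size
    partial_outputs = []
    for rank in range(tp_size):
        shard = slice(rank * chunk, (rank + 1) * chunk)
        # FC1, column parallel: this rank owns a slice of FC1's output features (rows of w1).
        h_local = [gelu(v + b) for v, b in zip(matvec(w1[shard], x), b1[shard])]
        # FC2, row parallel: this rank owns the matching slice of FC2's input features.
        w2_local = [row[shard] for row in w2]
        partial_outputs.append(matvec(w2_local, h_local))
    # Stand-in for the all-reduce (sum) across the tensor-parallel group.
    reduced = [sum(vals) for vals in zip(*partial_outputs)]
    # FC2 bias is added once, after the reduction.
    return [v + b for v, b in zip(reduced, b2)]


x = [0.3, -0.7]
w1 = [[0.1, 0.2], [-0.3, 0.4], [0.5, -0.6], [0.7, 0.8]]   # [ffn_hidden_size=4, hidden_size=2]
b1 = [0.01, 0.02, 0.03, 0.04]
w2 = [[0.1, -0.2, 0.3, -0.4], [0.5, 0.6, -0.7, 0.8]]     # [hidden_size=2, ffn_hidden_size=4]
b2 = [0.0, 0.1]
reference = mlp(x, w1, b1, w2, b2)
sharded = mlp_tensor_parallel(x, w1, b1, w2, b2, tp_size=2)
print(max(abs(a - b) for a, b in zip(reference, sharded)) < 1e-12)  # True
```

The only cross-rank communication in this pattern is the one reduction after FC2, which is the usual motivation for pairing the two parallel modes.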
- class transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs)
Scaled and masked softmax module for paddle with fused optimizations.
- Parameters
attn_mask_type (str, default = causal) – type of attention mask, can be ‘causal’, ‘padding’, or ‘no_mask’.
mask_func (callable) – custom callable for applying the mask to the softmax input. masked_input=mask_func(inp, mask).
softmax_in_fp32 (bool, default = True) – perform softmax computation in fp32.
layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for operation.
- forward(inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None)
Forward pass (fprop) of FusedScaleMaskSoftmax.
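The unfused computation that this module fuses (optional scaling, masking through mask_func, then softmax) can be written as a short reference. This sketch makes two assumptions not stated above: a mask value of True means "masked out", and the example mask_func (hypothetical) replaces masked scores with a large negative number. Python floats are double precision, so this loosely corresponds to softmax_in_fp32=True.

```python
import math
from typing import Callable, List, Optional

Matrix = List[List[float]]
Mask = List[List[bool]]


def example_mask_func(scores: Matrix, mask: Mask) -> Matrix:
    # Hypothetical mask_func: positions where mask is True are replaced by a large negative value.
    return [[-10000.0 if m else s for s, m in zip(srow, mrow)] for srow, mrow in zip(scores, mask)]


def causal_mask(seq_len: int) -> Mask:
    # True above the diagonal: query i must not attend to key j > i.
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


def scale_mask_softmax(scores: Matrix, mask: Optional[Mask],
                       mask_func: Callable[[Matrix, Mask], Matrix],
                       scale: Optional[float] = None) -> Matrix:
    if scale is not None:
        scores = [[s * scale for s in row] for row in scores]
    if mask is not None:
        scores = mask_func(scores, mask)
    probs = []
    for row in scores:
        row_max = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


probs = scale_mask_softmax([[0.0] * 3 for _ in range(3)], causal_mask(3), example_mask_func, scale=0.125)
print(probs[1])  # [0.5, 0.5, 0.0]
```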
- class transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.
Note
Argument attention_mask will be ignored in the forward call when attn_mask_type is set to “causal”.
- Parameters
norm_factor (float) – normalization factor for the attention scores.
attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.
attn_mask_type ({‘causal’, ‘padding’, ‘no_mask’}, default = causal) – type of attention mask passed into softmax operation.
attention_type ({‘self’, ‘cross’}, default = self) – type of attention operation.
backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for attention operation.
- forward(query_layer: paddle.Tensor, key_value_layer: paddle.Tensor = None, attention_mask: Optional[paddle.Tensor] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[paddle.Tensor] = None, set_zero: bool = True)
Dot Product Attention Layer.
Note
Argument attention_mask will be ignored when attn_mask_type is set to “causal”.
Note
For self attention, query_layer is the [query, key, value] tensor stacked along dimension 2 (zero-based), which must be of shape (batch_size, seq_length, 3, num_attention_heads, size_per_head), and key_value_layer is None. For cross attention, query_layer is the [query] tensor, which must be of shape (batch_size, seq_length, num_attention_heads, size_per_head), and key_value_layer is the [key, value] tensor, which must be of shape (batch_size, seq_length, 2, num_attention_heads, size_per_head).
- Parameters
query_layer (paddle.Tensor) – Query tensor.
key_value_layer (paddle.Tensor) – Packed [key, value] tensor for cross attention; None for self attention.
attention_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out the softmax input. Ignored when attn_mask_type is set to “causal”.
core_attention_bias_type (str, default = no_bias) – type of bias applied to the attention scores. Only no_bias is currently supported.
core_attention_bias (Optional[paddle.Tensor], default = None) – Bias tensor for Q * K.T
set_zero (bool, default = True) – Whether to use the fast path to set output tensors to 0 or not.
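The packed layouts described in the note can be made concrete with nested lists. This sketch (hypothetical helper names, no Paddle dependency) builds a (batch_size, seq_length, 3, num_attention_heads, size_per_head) tensor and splits it along dimension 2 into query, key and value tensors of shape (batch_size, seq_length, num_attention_heads, size_per_head); the same split with n = 2 applies to the cross-attention key_value_layer.

```python
from typing import Any, List, Tuple


def shape(t: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0] if t else None
    return tuple(dims)


def unpack(packed: list, n: int) -> List[list]:
    """Split a (batch, seq, n, heads, head_dim) nested list along dimension 2."""
    return [[[tok[i] for tok in seq] for seq in packed] for i in range(n)]


def make_packed(batch: int, seq: int, n: int, heads: int, head_dim: int) -> list:
    # Every element of slot i holds the value i, so q/k/v are easy to tell apart.
    return [[[[[float(i)] * head_dim for _ in range(heads)] for i in range(n)]
             for _ in range(seq)] for _ in range(batch)]


qkv = make_packed(batch=2, seq=4, n=3, heads=2, head_dim=8)
print(shape(qkv))            # (2, 4, 3, 2, 8)
q, k, v = unpack(qkv, 3)
print(shape(q), q[0][0][0][0], v[1][3][1][7])  # (2, 4, 2, 8) 0.0 2.0
```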
- class transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs)
Multi-head Attention (MHA), including Query, Key, Value and Output projection.
- Parameters
hidden_size (int) – hidden size of the model.
num_attention_heads (int) – number of attention heads.
attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.
layernorm_epsilon (float, default = 1e-5) – epsilon to use in the layer norm operations.
weight_attr (Union[paddle.ParamAttr, None], default = None) – paddle.ParamAttr object for the weight parameter.
bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – paddle.ParamAttr object for the bias parameter.
attn_mask_type ({‘causal’, ‘padding’, ‘no_mask’}, default = causal) – type of attention mask passed into softmax operation.
params_dtype (Optional[paddle.dtype], default = None) – data type for the weights and biases.
return_layernorm_output (bool, default = False) – whether to return the output of the layernorm operation.
input_layernorm (bool, default = False) – whether to apply layernorm to the input.
attention_type ({‘self’, ‘cross’}, default = self) – type of attention operation.
zero_centered_gamma (bool, default = False) – whether to zero initialize the gamma of the layernorm operation.
backend ({‘transformer_engine’, ‘paddle’}, default = transformer_engine) – backend to use for attention operation. If set to ‘paddle’, a framework-only, non-FP8 path is executed with limited optimization.
- Parallelism parameters
set_parallel_mode (bool, default = False) – if set to True, the QKV layer is used as Column Parallel whereas the PROJ layer is used as Row Parallel as described here.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
rng_state_name (str, default = local_seed) – Controls the RNG state used for dropout on the attention probabilities. The specified RNG should be seeded differently on each TP rank. It is ignored if set_parallel_mode is False. The specified name should be registered through paddle.distributed.fleet.meta_parallel.get_rng_state_tracker().add(rng_state_name, seed).
- forward(hidden_states: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, encoder_output: Optional[paddle.Tensor] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[paddle.Tensor] = None, set_zero: bool = True, recompute_core_attention: bool = False)
MultiHeadAttention Layer.
- Parameters
hidden_states (paddle.Tensor) – Input tensor.
attention_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out the softmax input.
encoder_output (Optional[paddle.Tensor], default = None) – Output of the encoder layer.
core_attention_bias_type (str, default = no_bias) – type of bias applied to the attention scores. Only no_bias is currently supported.
core_attention_bias (Optional[paddle.Tensor], default = None) – Bias tensor for Q * K.T
set_zero (bool, default = True) – Whether to use the fast path to set output tensors to 0 or not.
recompute_core_attention (bool, default = False) – If true, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.
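The relationship between hidden_size, num_attention_heads and the per-rank work under set_parallel_mode can be sketched as plain arithmetic. The sketch assumes the usual layout: each head gets hidden_size / num_attention_heads channels, the fused QKV projection produces 3 × hidden_size features, and heads are split evenly across tensor-parallel ranks. These are assumptions for illustration, not read from the Transformer Engine source; HeadLayout and head_layout are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HeadLayout:
    size_per_head: int
    qkv_out_features: int
    heads_per_rank: int
    qkv_out_features_per_rank: int


def head_layout(hidden_size: int, num_attention_heads: int, tp_size: int = 1) -> HeadLayout:
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    if num_attention_heads % tp_size != 0:
        raise ValueError("num_attention_heads must be divisible by the tensor-parallel size")
    size_per_head = hidden_size // num_attention_heads
    heads_per_rank = num_attention_heads // tp_size
    # Column-parallel QKV: each rank produces Q, K and V only for its own heads.
    return HeadLayout(size_per_head, 3 * hidden_size, heads_per_rank,
                      3 * heads_per_rank * size_per_head)


print(head_layout(1024, 16, tp_size=4))
# HeadLayout(size_per_head=64, qkv_out_features=3072, heads_per_rank=4, qkv_out_features_per_rank=768)
```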
- class transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.
- Parameters
hidden_size (int) – size of each input sample.
ffn_hidden_size (int) – intermediate size to which input samples are projected.
num_attention_heads (int) – number of attention heads in the transformer layer.
layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
hidden_dropout (float, default = 0.1) – dropout probability for the dropout op after FC2 layer.
attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.
weight_attr (Union[paddle.ParamAttr, None], default = None) – optional paddle.ParamAttr for weight.
bias_attr (Union[paddle.ParamAttr, None, bool], default = None) – optional paddle.ParamAttr for bias.
self_attn_mask_type ({‘causal’, ‘padding’}, default = causal) – type of attention mask passed into softmax operation.
apply_residual_connection_post_layernorm (bool, default = False) – if set to True, residual connections are taken from the output of layer norm (by default, they are taken from the input of layer norm).
output_layernorm (bool, default = False) – if set to True, layer normalization is applied on the output side, after the final dropout-add. The default behavior is to apply layer normalization on the input side, before the QKV transformation.
layer_type ({‘encoder’, ‘decoder’}, default = encoder) – if set to decoder, an additional cross-attn block is added after self-attn. This can be used for structures like T5 Transformer in conjunction with the encoder option.
zero_centered_gamma (bool, default = False) – if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
activation (str, default = 'gelu') – type of activation used in the MLP block. Options are: ‘gelu’, ‘relu’, ‘reglu’, ‘geglu’ and ‘swiglu’.
params_dtype (paddle.dtype, default = paddle.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
backend ({'transformer_engine', 'paddle'}, default = 'transformer_engine') – if set to ‘paddle’, a framework-only, non-FP8 path is executed with limited optimization.
- Parallelism parameters
set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 are used as Row Parallel as described here.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
attention_dropout_rng_state_name (str, default = local_seed) – Controls the RNG state used for dropout on the attention probabilities. The specified RNG should be seeded differently on each TP rank. It is ignored if set_parallel_mode is False.
hidden_dropout_rng_state_name (str, default = global_seed) – Controls the RNG state used for dropout on the hidden states. The specified RNG should be given the same seed on every TP rank. It is ignored if set_parallel_mode is False. The specified name should be registered through paddle.distributed.fleet.meta_parallel.get_rng_state_tracker().add(rng_state_name, seed).
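The seeding rules for the two RNG states can be illustrated with Python's random module standing in for Paddle's RNG state tracker. A plausible reading, consistent with the Megatron-style split: attention heads are partitioned across TP ranks, so each rank's slice of the attention probabilities should get independent dropout, while hidden states are replicated after the row-parallel reduction, so every rank must drop the same elements to keep the replicas identical. make_rng_states and the seed offset of 1000 are hypothetical.

```python
import random
from typing import Dict, List


def dropout_keep_mask(rng: random.Random, n: int, p: float) -> List[bool]:
    # True means the element is kept; each element is dropped with probability p.
    return [rng.random() >= p for _ in range(n)]


def make_rng_states(tp_rank: int, base_seed: int) -> Dict[str, random.Random]:
    """Hypothetical per-rank RNG registry mirroring the two state names above."""
    return {
        # Same seed on every TP rank: hidden-state dropout masks agree across ranks.
        "global_seed": random.Random(base_seed),
        # Seed depends on the rank: attention-prob dropout masks differ across ranks.
        # The offset of 1000 is arbitrary; it only keeps the two streams apart.
        "local_seed": random.Random(base_seed + 1000 + tp_rank),
    }


ranks = [make_rng_states(tp_rank, base_seed=42) for tp_rank in range(2)]
hidden = [dropout_keep_mask(r["global_seed"], 16, p=0.1) for r in ranks]
attn = [dropout_keep_mask(r["local_seed"], 64, p=0.5) for r in ranks]
print(hidden[0] == hidden[1])  # True
```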
- forward(hidden_states: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, encoder_output: Optional[paddle.Tensor] = None, enc_dec_attn_mask: Optional[paddle.Tensor] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[paddle.Tensor] = None, set_zero: bool = True, recompute_core_attention: bool = False)
Transformer Layer: attention block and a feedforward network (MLP)
Note
Argument attention_mask will be ignored when self_attn_mask_type is set to “causal”.
- Parameters
hidden_states (paddle.Tensor) – Input tensor.
attention_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out self-attention softmax input.
encoder_output (Optional[paddle.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type=“decoder”.
enc_dec_attn_mask (Optional[paddle.Tensor], default = None) – Boolean tensor used to mask out inter-attention softmax input if using layer_type=“decoder”.
core_attention_bias_type (str, default = no_bias) – type of bias applied to the attention scores. Only no_bias is currently supported.
core_attention_bias (Optional[paddle.Tensor], default = None) – Bias tensor for Q * K.T
set_zero (bool, default = True) – Whether to set output tensors to 0 or not before use.
recompute_core_attention (bool, default = False) – If true, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.
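The effect of apply_residual_connection_post_layernorm and output_layernorm on one attention or MLP sub-block can be sketched as follows. This is a simplified, framework-free reading of the parameter descriptions: dropout and the LayerNorm affine parameters are omitted, output_layernorm takes precedence if both flags are set, and residual_block is a hypothetical name, not the layer's actual code path.

```python
import math
from typing import Callable, List

Vector = List[float]


def ln(x: Vector, eps: float = 1e-5) -> Vector:
    # LayerNorm without the affine parameters (gamma = 1, beta = 0).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def residual_block(x: Vector, sublayer: Callable[[Vector], Vector],
                   apply_residual_connection_post_layernorm: bool = False,
                   output_layernorm: bool = False) -> Vector:
    """One attention or MLP sub-block; dropout is omitted."""
    if output_layernorm:
        # Post-LN: the sublayer sees the raw input, and LayerNorm follows the residual add.
        return ln([a + b for a, b in zip(x, sublayer(x))])
    normed = ln(x)  # Pre-LN (default): LayerNorm precedes the sublayer.
    residual = normed if apply_residual_connection_post_layernorm else x
    return [r + s for r, s in zip(residual, sublayer(normed))]


def zero(v: Vector) -> Vector:
    # A sublayer that contributes nothing, which isolates the residual path.
    return [0.0] * len(v)


x = [1.0, 2.0, 3.0]
print(residual_block(x, zero))  # [1.0, 2.0, 3.0]
```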
- transformer_engine.paddle.fp8_autocast(enabled: bool = False, calibrating: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, fp8_group: Optional[transformer_engine.paddle.constants.dist_group_type] = None)
Context manager for FP8 usage.
with fp8_autocast(enabled=True):
    out = model(inp)
Note
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding the sequence length to a multiple of 16.
Note
When fp8_recipe.reduce_amax==True, a module must not be invoked more than once inside a single fp8_autocast region. This is unsupported because the amax reduction is handled when the fp8_autocast context exits; calling the same module more than once inside an fp8_autocast region overwrites its amax tensors before the reduction can occur.
- Parameters
enabled (bool, default = False) – whether or not to enable fp8
calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision.
fp8_recipe (recipe.DelayedScaling, default = None) – recipe used for FP8 training.
fp8_group (paddle.distributed.collective.Group, default = None) – distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step.
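The divisible-by-16 requirement from the first note can be checked before entering fp8_autocast. This sketch assumes the Linear input is viewed as a 2D (batch_size * seq_length, hidden_size) matrix and that padded token positions are excluded through an attention mask; the helper names are hypothetical.

```python
from typing import List


def pad_to_multiple(n: int, multiple: int = 16) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def fp8_gemm_shape_ok(rows: int, cols: int, multiple: int = 16) -> bool:
    # Both dimensions must be divisible by 16, per the FP8 note.
    return rows % multiple == 0 and cols % multiple == 0


def pad_tokens(tokens: List[int], pad_id: int = 0, multiple: int = 16) -> List[int]:
    # Right-pad a token sequence; the padded positions would normally be masked out.
    return tokens + [pad_id] * (pad_to_multiple(len(tokens), multiple) - len(tokens))


batch_size, seq_len, hidden_size = 3, 100, 1024
padded_len = pad_to_multiple(seq_len)                            # 112
print(fp8_gemm_shape_ok(batch_size * seq_len, hidden_size))     # False: 300 % 16 == 12
print(fp8_gemm_shape_ok(batch_size * padded_len, hidden_size))  # True: 336 == 21 * 16
```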
- transformer_engine.paddle.recompute(function, *args, **kwargs)
This is a wrapper around paddle.distributed.fleet.utils.recompute. It provides the state information that FP8 layers need during recomputation.
- Parameters
function (Callable) – paddle module used to run the forward and backward passes using the specified args and kwargs.
args (tuple) – tuple of paddle tensors for inputs to function.
kwargs (dict) – dictionary of string keys for keyword arguments to function.
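The general problem such a wrapper addresses can be illustrated without Paddle: recomputation only reproduces the original activations if any state the forward pass depends on is restored first. This sketch uses Python's random-module state as a stand-in for that state (for FP8 layers it would be scaling metadata, which this sketch does not model); Checkpointed and noisy_layer are hypothetical names, and this is not how the Transformer Engine wrapper is implemented.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def noisy_layer(x: Vector, rng: random.Random) -> Vector:
    # Stand-in for a layer whose forward depends on mutable state (here, dropout randomness).
    return [v if rng.random() >= 0.5 else 0.0 for v in x]


class Checkpointed:
    """Hypothetical recompute wrapper: stores inputs plus a state snapshot, not activations."""

    def __init__(self, fn: Callable[[Vector, random.Random], Vector], rng: random.Random):
        self.fn = fn
        self.rng = rng
        self.saved_input: Optional[Vector] = None
        self.saved_state = None

    def forward(self, x: Vector) -> Vector:
        self.saved_input = list(x)
        self.saved_state = self.rng.getstate()  # snapshot taken before the forward runs
        return self.fn(x, self.rng)

    def recompute(self) -> Vector:
        # Replay from the snapshot so the recomputed activations match the original ones.
        replay_rng = random.Random()
        replay_rng.setstate(self.saved_state)
        return self.fn(self.saved_input, replay_rng)


layer = Checkpointed(noisy_layer, random.Random(0))
out = layer.forward([1.0] * 8)
print(layer.recompute() == out)  # True
```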