PyTorch
- class transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
Applies a linear transformation to the incoming data: \(y = xA^T + b\).
On NVIDIA GPUs it is a drop-in replacement for torch.nn.Linear.
- Parameters
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
bias (bool, default = True) – if set to False, the layer will not learn an additive bias.
init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
parameters_split (Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None) – if a tuple of strings or a dict of strings to integers is provided, the weight and bias parameters of the module are each exposed as N separate torch.nn.parameter.Parameter objects, split along the first dimension, where N is the length of the argument and the strings it contains are the names of the split parameters. In the case of a tuple, each parameter has the same shape. In the case of a dict, the values give the out_features for each projection.
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
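The parameters_split bookkeeping can be illustrated with a small plain-Python sketch that only computes the shape of each split weight. The helper name split_weight_shapes is hypothetical and not a Transformer Engine function. The two validity checks (dict values summing to out_features, out_features divisible by the tuple length) are assumptions about what a well-formed argument looks like. Weights follow the torch.nn.Linear (out_features, in_features) layout, so "split along the first dimension" splits the output features.

```python
from typing import Dict, List, Tuple, Union


def split_weight_shapes(
    in_features: int,
    out_features: int,
    parameters_split: Union[Tuple[str, ...], Dict[str, int]],
) -> List[Tuple[str, Tuple[int, int]]]:
    """Return (name, weight_shape) per split; weights use the (out_features,
    in_features) layout of torch.nn.Linear and are split along dimension 0."""
    if isinstance(parameters_split, dict):
        # Dict: each value is the out_features of one projection
        # (assumed: the values sum to out_features).
        if sum(parameters_split.values()) != out_features:
            raise ValueError("split sizes must sum to out_features")
        return [(name, (size, in_features)) for name, size in parameters_split.items()]
    # Tuple: N equally sized splits (assumed: out_features divisible by N).
    n = len(parameters_split)
    if out_features % n != 0:
        raise ValueError("out_features must be divisible by the number of splits")
    return [(name, (out_features // n, in_features)) for name in parameters_split]


# Fused QKV projection: equal splits vs. unequal (GQA-style) splits.
equal = split_weight_shapes(1024, 3 * 1024, ("query", "key", "value"))
unequal = split_weight_shapes(1024, 1536, {"query": 1024, "key": 256, "value": 256})
```

The dict form is the one to reach for when projections differ in width, for example smaller key/value projections than query.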
- Parallelism parameters
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.
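The Column/Row distinction follows the usual Megatron-style tensor parallelism, shown here as a single-process plain-Python sketch. Column parallel shards the output features across TP ranks. Row parallel shards the input features, and each rank then holds a partial sum that must be reduced across the group. The helper names are hypothetical. The sketch lowercases parallel_mode, so the doc's 'Column'/'Row' spellings are accepted. Summing the per-rank partial outputs stands in for the all-reduce collective.

```python
from typing import List, Optional, Tuple

Matrix = List[List[int]]


def linear(x: Matrix, weight: Matrix) -> Matrix:
    """y = x A^T for x of shape (n, in_features) and A of shape (out_features, in_features)."""
    return [[sum(a * w for a, w in zip(row, w_row)) for w_row in weight] for row in x]


def local_weight_shape(
    in_features: int, out_features: int, parallel_mode: Optional[str], tp_size: int
) -> Tuple[int, int]:
    """Per-rank weight shape under Megatron-style tensor parallelism."""
    if parallel_mode is None:
        return (out_features, in_features)  # unsharded, no communication
    mode = parallel_mode.lower()
    if mode == "column":
        # Column parallel: each rank owns a slice of the output features.
        return (out_features // tp_size, in_features)
    if mode == "row":
        # Row parallel: each rank owns a slice of the input features.
        return (out_features, in_features // tp_size)
    raise ValueError(f"unknown parallel_mode: {parallel_mode!r}")


def row_parallel_linear(x: Matrix, weight: Matrix, tp_size: int) -> Matrix:
    """Simulate row parallelism in one process: each 'rank' multiplies its slice
    of the input by its slice of the weight; summing the partial outputs plays
    the role of the all-reduce across the tensor parallel group."""
    shard = len(weight[0]) // tp_size
    partials = []
    for rank in range(tp_size):
        lo, hi = rank * shard, (rank + 1) * shard
        partials.append(linear([row[lo:hi] for row in x], [w[lo:hi] for w in weight]))
    return [
        [sum(p[i][j] for p in partials) for j in range(len(weight))]
        for i in range(len(x))
    ]
```

Integer inputs keep the comparison against the unsharded result exact. With floating point, the reduced result can differ in the last bits because of summation order.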
- Optimization parameters
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the dtype used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
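The return_bias contract can be sketched for a single sample in plain Python. With return_bias=True, the layer computes only \(y = xA^T\) and hands the bias back so a later step can apply it. The consumer bias_residual_add is a hypothetical example of such a step (bias plus residual in one pass). It is not a Transformer Engine function.

```python
from typing import List, Tuple, Union

Vector = List[float]


def linear_forward(
    x: Vector, weight: List[Vector], bias: Vector, return_bias: bool = False
) -> Union[Vector, Tuple[Vector, Vector]]:
    """Single-sample y = x A^T (+ b); weight rows are output features."""
    y = [sum(a * w for a, w in zip(x, w_row)) for w_row in weight]
    if return_bias:
        # Bias is NOT applied here; the caller is expected to fold it into a later op.
        return y, bias
    return [v + b for v, b in zip(y, bias)]


def bias_residual_add(y: Vector, bias: Vector, residual: Vector) -> Vector:
    """Hypothetical downstream op that applies bias and residual in one pass."""
    return [v + b + r for v, b, r in zip(y, bias, residual)]
```

Both paths give the same final values. The point of return_bias is only to let the bias addition happen inside whatever the next operation is.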
- forward(inp: torch.Tensor, is_first_microbatch: Optional[bool] = None)
Apply the linear transformation to the input.
- Parameters
inp (torch.Tensor) – Input tensor.
is_first_microbatch ({True, False, None}, default = None) –
During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Model weights are not updated between microbatches of the same minibatch. Setting this parameter indicates whether the current microbatch is the first in a minibatch. When set, this parameter enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
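The flag pattern for one minibatch, and why the first microbatch can skip accumulation, can be modeled with a toy scalar in plain Python. The helper names are hypothetical, and the "buffer" is a single float standing in for a weight's gradient buffer. The commented call shows where the documented forward(inp, is_first_microbatch=...) argument would go in real training. FP8 weight caching is not modeled.

```python
from typing import List


def first_microbatch_flags(num_microbatches: int) -> List[bool]:
    """Value to pass as is_first_microbatch for each microbatch of one minibatch."""
    return [i == 0 for i in range(num_microbatches)]


def accumulate_wgrad(microbatch_grads: List[float]) -> float:
    """Toy scalar model of a weight-gradient buffer across one minibatch."""
    if not microbatch_grads:
        raise ValueError("need at least one microbatch")
    main_grad = float("nan")  # stale buffer contents; never read on the first microbatch
    flags = first_microbatch_flags(len(microbatch_grads))
    for is_first, grad in zip(flags, microbatch_grads):
        # Real training would call something like:
        #   out = layer(inp, is_first_microbatch=is_first)
        if is_first:
            main_grad = grad   # first gradient of the minibatch: overwrite, nothing to add
        else:
            main_grad += grad  # later microbatches: accumulate (weights unchanged in between)
    return main_grad
```

Because the first microbatch overwrites the buffer, its previous contents (here NaN) never affect the result.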
- set_tensor_parallel_group(tp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None])
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
Applies Layer Normalization over a mini-batch of inputs as described in the paper Layer Normalization
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta\]
where \(\gamma\) and \(\beta\) are learnable affine transform parameters of size hidden_size.
- Parameters
hidden_size (int) – size of each input sample.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the dtype used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
zero_centered_gamma (bool, default = False) –
if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
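Both LayerNorm formulas above can be checked against a plain-Python reference over a single hidden_size-long vector. This is an unfused sketch of the math, not Transformer Engine's kernel. It uses the biased variance \(\mathrm{E}[(x - \mathrm{E}[x])^2]\), as torch.nn.LayerNorm does. With zero_centered_gamma, a gamma of all zeros gives the same output as a standard gamma of all ones.

```python
import math
from typing import List


def layer_norm(
    x: List[float],
    gamma: List[float],
    beta: List[float],
    eps: float = 1e-5,
    zero_centered_gamma: bool = False,
) -> List[float]:
    """Reference LayerNorm over one hidden_size-long vector."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance E[(x - E[x])^2]
    inv_std = 1.0 / math.sqrt(var + eps)
    # zero_centered_gamma: gamma starts at 0 and the effective scale is (1 + gamma).
    scale = [1.0 + g for g in gamma] if zero_centered_gamma else gamma
    return [(v - mean) * inv_std * s + b for v, s, b in zip(x, scale, beta)]
```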
- class transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in the paper Root Mean Square Layer Normalization
\[y = \frac{x}{RMS_\varepsilon(x)} * \gamma\]
where
\[RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \varepsilon}\]
and \(\gamma\) is a learnable affine transform parameter of size hidden_size.
- Parameters
hidden_size (int) – size of each input sample.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the dtype used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
zero_centered_gamma (bool, default = False) –
if set to True, the gamma parameter in RMSNorm is initialized to 0 and the RMSNorm formula changes to
\[y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma)\]
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
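A plain-Python reference for the RMSNorm definition above follows. The mean of squares runs over all n = hidden_size elements, and \(\varepsilon\) sits inside the square root as in \(RMS_\varepsilon\). Unlike LayerNorm there is no mean subtraction and no \(\beta\). This is a sketch of the math only, not the library's implementation.

```python
import math
from typing import List


def rms_norm(
    x: List[float],
    gamma: List[float],
    eps: float = 1e-5,
    zero_centered_gamma: bool = False,
) -> List[float]:
    """Reference RMSNorm over one hidden_size-long vector (no mean subtraction, no beta)."""
    n = len(x)
    rms = math.sqrt(sum(v * v for v in x) / n + eps)  # RMS_eps(x): eps inside the sqrt
    # zero_centered_gamma: gamma starts at 0 and the effective scale is (1 + gamma).
    scale = [1.0 + g for g in gamma] if zero_centered_gamma else gamma
    return [v / rms * s for v, s in zip(x, scale)]
```

With eps = 0 and a unit gamma, the output has a mean square of exactly 1 up to rounding.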
- class transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
Applies layer normalization followed by linear transformation to the incoming data.
- Parameters
in_features (int) – size of each input sample.
out_features (int) – size of each output sample.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
bias (bool, default = True) – if set to False, the layer will not learn an additive bias.
normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.
init_method (Callable, default = None) – used for initializing weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.
parameters_split (Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None) – if a tuple of strings or a dict of strings to integers is provided, the weight and bias parameters of the module are each exposed as N separate torch.nn.parameter.Parameter objects, split along the first dimension, where N is the length of the argument and the strings it contains are the names of the split parameters. In the case of a tuple, each parameter has the same shape. In the case of a dict, the values give the out_features for each projection.
zero_centered_gamma (bool, default = False) –
if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
- Parallelism parameters
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
parallel_mode ({None, ‘Column’, ‘Row’}, default = None) – used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described here. When set to None, no communication is performed.
- Optimization parameters
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the dtype used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
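What LayerNormLinear computes, including the effect of return_layernorm_output, can be written out as an unfused single-sample sketch in plain Python. It covers only normalization='LayerNorm' with a plain bias and no parallelism, and the function names are hypothetical. The module itself runs both steps inside one layer. The sketch shows only the math and the shape of the return value.

```python
import math
from typing import List, Tuple, Union

Vector = List[float]


def layer_norm(x: Vector, gamma: Vector, beta: Vector, eps: float) -> Vector:
    """Reference LayerNorm (biased variance) over one vector."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]


def layernorm_linear(
    x: Vector,
    gamma: Vector,
    beta: Vector,
    weight: List[Vector],
    bias: Vector,
    eps: float = 1e-5,
    return_layernorm_output: bool = False,
) -> Union[Vector, Tuple[Vector, Vector]]:
    """Layer normalization followed by y = ln(x) A^T + b."""
    ln_out = layer_norm(x, gamma, beta, eps)
    y = [sum(a * w for a, w in zip(ln_out, w_row)) + b for w_row, b in zip(weight, bias)]
    if return_layernorm_output:
        # e.g. when the transformer's residual connection is taken after the layernorm
        return y, ln_out
    return y
```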
- forward(inp: torch.Tensor, is_first_microbatch: Optional[bool] = None)
Apply layer normalization to the input followed by a linear transformation.
- Parameters
inp (torch.Tensor) – Input tensor.
is_first_microbatch ({True, False, None}, default = None) –
During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Model weights are not updated between microbatches of the same minibatch. Setting this parameter indicates whether the current microbatch is the first in a minibatch. When set, this parameter enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
- set_tensor_parallel_group(tp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None])
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
Applies layer normalization on the input followed by the MLP module, consisting of 2 successive linear transformations separated by an activation function (GeLU by default; see the activation parameter).
- Parameters
hidden_size (int) – size of each input sample.
ffn_hidden_size (int) – intermediate size to which input samples are projected.
eps (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
bias (bool, default = True) – if set to False, the FC1 and FC2 layers will not learn an additive bias.
normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.
activation (str, default = 'gelu') – activation function used. Options: ‘gelu’, ‘geglu’, ‘relu’, ‘reglu’, ‘squared_relu’, ‘swiglu’.
init_method (Callable, default = None) – used for initializing FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
output_layer_init_method (Callable, default = None) – used for initializing FC2 weights in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.
zero_centered_gamma (bool, default = False) –
if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
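The activation options split into two families, sketched below in plain Python for a single FC1 output vector. The plain ones ('gelu', 'relu', 'squared_relu') act element-wise. The gated ones ('geglu', 'reglu', 'swiglu') combine two ffn_hidden_size-wide halves of the FC1 output. Several details are assumptions rather than documented facts: that FC1 produces 2 * ffn_hidden_size features for gated variants, that the first half is activated and multiplies the second, and the use of exact erf-based GeLU (a fused kernel may use the tanh approximation).

```python
import math
from typing import Callable, Dict, List


def gelu(v: float) -> float:
    # Exact erf-based GeLU; a fused kernel may use the tanh approximation instead.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def relu(v: float) -> float:
    return max(v, 0.0)


def squared_relu(v: float) -> float:
    return relu(v) ** 2


def silu(v: float) -> float:
    # Numerically stable v * sigmoid(v): avoids overflow in math.exp for large |v|.
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


PLAIN: Dict[str, Callable[[float], float]] = {"gelu": gelu, "relu": relu, "squared_relu": squared_relu}
GATED: Dict[str, Callable[[float], float]] = {"geglu": gelu, "reglu": relu, "swiglu": silu}


def fc1_out_features(ffn_hidden_size: int, activation: str) -> int:
    """Assumed: gated variants need two ffn_hidden_size-wide halves from FC1."""
    return 2 * ffn_hidden_size if activation in GATED else ffn_hidden_size


def apply_activation(h: List[float], activation: str) -> List[float]:
    """Map one FC1 output vector h to the ffn_hidden_size-wide FC2 input."""
    if activation in PLAIN:
        return [PLAIN[activation](v) for v in h]
    act = GATED[activation]
    half = len(h) // 2
    # Assumed: the first half is activated and gates the second half element-wise.
    return [act(a) * b for a, b in zip(h[:half], h[half:])]
```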
- Parallelism parameters
set_parallel_mode (bool, default = False) – if set to True, FC1 is used as Column Parallel and FC2 is used as Row Parallel as described here.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
- Optimization parameters
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias for FC2, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the dtype used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
seq_length (int) – sequence length of input samples. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for forward propagation and for the activation recompute phase.
micro_batch_size (int) – batch size per training step. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for forward propagation and for the activation recompute phase.
- forward(inp: torch.Tensor, is_first_microbatch: Optional[bool] = None)
Apply layer normalization to the input followed by a feedforward network (MLP Block).
- Parameters
inp (torch.Tensor) – Input tensor.
is_first_microbatch ({True, False, None}, default = None) –
During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Model weights are not updated between microbatches of the same minibatch. Setting this parameter indicates whether the current microbatch is the first in a minibatch. When set, this parameter enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
- set_tensor_parallel_group(tp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None])
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.
Note
Argument attention_mask in the forward call is only used when attn_mask_type includes “padding” or “arbitrary”.
Warning
FlashAttention uses a non-deterministic algorithm for optimal performance. To observe deterministic behavior at the cost of performance, use FlashAttention version < 2.0.0 and set the environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO=0. In order to disable flash-attn entirely, set NVTE_FLASH_ATTN=0.
- Parameters
num_attention_heads (int) – number of attention heads in the transformer layer.
kv_channels (int) – number of key-value channels.
num_gqa_groups (Optional[int] = None) – number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.
attention_dropout (float, default = 0.0) – dropout probability for the dropout op during multi-head attention.
attn_mask_type (str, default = causal) – type of attention mask passed into softmax operation. Options are “no_mask”, “padding”, “causal”, “padding,causal”, “causal,padding”, and “arbitrary”, where “padding,causal” and “causal,padding” are equivalent. This arg can be overridden by attn_mask_type in the forward method. The init arg is useful for cases involving compilation/tracing, e.g. ONNX export, while the forward arg is useful for dynamically changing mask types, e.g. a different mask for training and inference. For “no_mask”, no attention mask is applied. For “causal” or the causal mask in “padding,causal”, TransformerEngine calculates and applies an upper triangular mask to the softmax input. No user input is needed. For “padding” or the padding mask in “padding,causal”, users need to provide the locations of padded tokens either via cu_seqlens_q and cu_seqlens_kv in the shape [batch_size + 1] or via attention_mask in the shape [batch_size, 1, 1, max_seq_len]. For the “arbitrary” mask, users need to provide a mask that is broadcastable to the shape of the softmax input.
attention_type (str, default = self) – type of attention, either “self” or “cross”.
layer_number (int, default = None) – layer number of the current DotProductAttention when multiple such modules are concatenated, for instance in consecutive transformer blocks.
qkv_format (str, default = sbhd) – dimension format for query_layer, key_layer and value_layer, {sbhd, bshd, thd}. s stands for the sequence length, b batch size, h the number of heads, d head size, and t the total number of tokens in a batch, with t = sum(s_i), for i = 0…b-1. The sbhd and bshd formats are used when sequences in a batch are of equal length or padded to equal length, and the thd format is used when sequences in a batch have different lengths. Please note that these formats do not reflect how tensors query_layer, key_layer, value_layer are laid out in memory. For that, please use _get_qkv_layout to obtain the layout information.
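The “causal”, “padding” and “padding,causal” options described above can be illustrated by building the masks in plain Python. Two points are assumptions: that True marks a position to be masked out (excluded from the softmax), and that the causal mask is square (seqlen_q == seqlen_kv). For “causal”, TransformerEngine builds the mask itself, so this sketch only illustrates what is applied. The padding mask uses the documented [batch_size, 1, 1, max_seq_len] shape and broadcasts over heads and query positions.

```python
from typing import List

Mask2D = List[List[bool]]


def causal_mask(seqlen: int) -> Mask2D:
    """[seqlen, seqlen]; True hides key j from query i when j > i (upper triangle)."""
    return [[j > i for j in range(seqlen)] for i in range(seqlen)]


def padding_mask(seqlens: List[int], max_seqlen: int) -> List[List[List[List[bool]]]]:
    """[batch_size, 1, 1, max_seqlen]; True marks padded token positions."""
    return [[[[t >= n for t in range(max_seqlen)]]] for n in seqlens]


def padding_causal_mask(seqlens: List[int], max_seqlen: int) -> List[Mask2D]:
    """Per-sequence [max_seqlen, max_seqlen] union of both masks ("padding,causal")."""
    causal = causal_mask(max_seqlen)
    pad = padding_mask(seqlens, max_seqlen)
    return [
        [[causal[i][j] or pad[b][0][0][j] for j in range(max_seqlen)] for i in range(max_seqlen)]
        for b in range(len(seqlens))
    ]
```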
- Parallelism parameters
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_size (int, default = 1) – tensor parallel world size.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
cp_group (ProcessGroup, default = None) – context parallel process group.
cp_global_ranks (list of global rank IDs, default = None) – global rank IDs of GPUs that are in cp_group.
cp_stream (CUDA stream, default = None) – context parallelism splits flash attention into multiple steps for compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels.
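The num_gqa_groups parameter above changes only how many key/value heads exist. Each group of num_attention_heads / num_gqa_groups consecutive query heads shares one key/value head. A plain-Python sketch of that mapping follows; the helper name is hypothetical. GQA-1 maps every query head to head 0 (MQA), and leaving num_gqa_groups unset maps each query head to its own key/value head (MHA). Since K and V carry only num_gqa_groups heads, their size shrinks by a factor of num_attention_heads / num_gqa_groups.

```python
from typing import List, Optional


def kv_head_for_query_head(num_attention_heads: int, num_gqa_groups: Optional[int] = None) -> List[int]:
    """Index of the key/value head used by each query head."""
    groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
    if num_attention_heads % groups != 0:
        raise ValueError("num_attention_heads must be divisible by num_gqa_groups")
    heads_per_group = num_attention_heads // groups
    return [h // heads_per_group for h in range(num_attention_heads)]
```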
- forward(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, qkv_format: Optional[str] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[torch.Tensor] = None, fast_zero_fill: bool = True)
Dot Product Attention Layer.
Note
Argument attention_mask is only used when attn_mask_type includes “padding” or “arbitrary”.
Note
Input tensors query_layer, key_layer, and value_layer must each be of shape (sequence_length, batch_size, num_attention_heads, kv_channels). Output of shape (sequence_length, batch_size, num_attention_heads * kv_channels) is returned.
Note
DotProductAttention supports three backends: 1) FlashAttention, which calls HazyResearch/Dao-AILab’s flash-attn PyTorch API, 2) FusedAttention, which has multiple fused attention implementations based on the cuDNN Graph API (see FusedAttention for more details on FusedAttention backends), and 3) UnfusedDotProductAttention, which is the native PyTorch implementation with fused scaled masked softmax.
Note
Users can use the environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, and NVTE_FUSED_ATTN_BACKEND to control which DotProductAttention backend, and FusedAttention backend if applicable, to use. TransformerEngine prioritizes FlashAttention over FusedAttention, and FusedAttention over UnfusedDotProductAttention. If FusedAttention is being used, users can also choose to switch to flash-attn’s implementation for backward by setting NVTE_FUSED_ATTN_USE_FAv2_BWD=1 (default: 0), because of the performance differences between various versions of flash-attn and FusedAttention. Further, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT can be used to enable (1) or disable (0) the workspace-related optimizations in FusedAttention. When unset, TransformerEngine determines the code path based on its internal logic. These optimizations trade memory for performance and should be used with care.
- Parameters
query_layer (torch.Tensor) – Query tensor.
key_layer (torch.Tensor) – Key tensor.
value_layer (torch.Tensor) – Value tensor.
attention_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], default = None) – Boolean tensor(s) used to mask out attention softmax input. It should be ‘None’ for ‘causal’ and ‘no_mask’ types. For ‘padding’ masks, it should be a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For the ‘arbitrary’ mask type, it should be in a shape that is broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
qkv_format (str, default = None) – If provided, overrides qkv_format from initialization.
cu_seqlens_q (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths in a batch for query_layer, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv (Optional[torch.Tensor], default = None) – Cumulative sum of sequence lengths in a batch for key_layer and value_layer, with shape [batch_size + 1] and dtype torch.int32.
attn_mask_type ({no_mask, padding, causal, padding,causal, causal,padding, arbitrary}, default = None) – Type of attention mask passed into softmax operation. ‘padding,causal’ and ‘causal,padding’ are equivalent.
checkpoint_core_attention (bool, default = False) – If true, forward activations for attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop.
core_attention_bias_type (str, default = no_bias) – Bias type, {no_bias, pre_scale_bias, post_scale_bias, alibi}
core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv]. It should be ‘None’ for ‘no_bias’ and ‘alibi’ bias types.
fast_zero_fill (bool, default = True) – Whether to use the fast path to set output tensors to 0 or not.
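The cu_seqlens_q/cu_seqlens_kv arguments and the thd format fit together as follows: cu_seqlens is the length-prefixed cumulative sum [0, s_0, s_0 + s_1, …], and sequence i occupies rows cu_seqlens[i] to cu_seqlens[i + 1] of the packed [t, h, d] tensor. The plain-Python sketch below shows that bookkeeping on token IDs; the helper names are hypothetical. In PyTorch one would typically build the real tensor with torch.cumsum over the lengths, prepend a 0, and cast to torch.int32.

```python
from itertools import accumulate
from typing import List, Sequence


def cu_seqlens_from_lengths(seqlens: Sequence[int]) -> List[int]:
    """[batch_size + 1] prefix sums starting at 0 (an int32 tensor in real use)."""
    return [0] + list(accumulate(seqlens))


def pack_thd(padded: Sequence[Sequence[int]], seqlens: Sequence[int]) -> List[int]:
    """Strip padding from a [b, s] batch and concatenate to [t], t = sum(seqlens)."""
    return [tok for seq, n in zip(padded, seqlens) for tok in seq[:n]]


def unpack_sequence(packed: Sequence[int], cu_seqlens: Sequence[int], i: int) -> List[int]:
    """Tokens of sequence i are packed[cu_seqlens[i]:cu_seqlens[i + 1]]."""
    return list(packed[cu_seqlens[i]:cu_seqlens[i + 1]])
```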
- set_context_parallel_group(cp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream)
Set the context parallel attributes for the given module before executing the forward pass.
- Parameters
cp_group (ProcessGroup) – context parallel process group.
cp_global_ranks (List[int]) – list of global ranks in the context group.
cp_stream (torch.cuda.Stream) – cuda stream for context parallel execution.
- class transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
Multi-head Attention (MHA), including Query, Key, Value and Output projection.
Note
Argument attention_mask in the forward call is only used when attn_mask_type includes “padding” or “arbitrary”.
- Parameters
hidden_size (int) – size of each input sample.
num_attention_heads (int) – number of attention heads in the transformer layer.
kv_channels (int, default = None) – number of key-value channels. Defaults to hidden_size / num_attention_heads if None.
attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.
layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
init_method (Callable, default = None) – used for initializing weights of QKV and FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
output_layer_init_method (Callable, default = None) – used for initializing weights of PROJ and FC2 in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
layer_number (int, default = None) – layer number of the current TransformerLayer when multiple such modules are concatenated to form a transformer block.
attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'}, default = causal) – type of attention mask passed into softmax operation. Overridden by attn_mask_type in the forward method. The forward arg is useful for dynamically changing mask types, e.g. a different mask for training and inference. The init arg is useful for cases involving compilation/tracing, e.g. ONNX export.
num_gqa_groups (int, default = None) –
number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.
return_layernorm_output (bool, default = False) – if set to True, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm.
input_layernorm (bool, default = True) – if set to False, layer normalization to the input is not applied.
attention_type ({ 'self', 'cross' }, default = 'self') – type of attention applied.
zero_centered_gamma (bool, default = False) –
if set to True, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.
qkv_weight_interleaved (bool, default = True) – if set to False, the QKV weight is interpreted as a concatenation of query, key, and value weights along the 0th dimension. The default interpretation is that the individual q, k, and v weights for each attention head are interleaved. This parameter is set to False when using fuse_qkv_params=False.
bias (bool, default = True) – if set to False, the transformer layer will not learn any additive biases.
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
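The qkv_weight_interleaved distinction is about row order in the fused QKV weight, sketched below in plain Python. Each row is labeled (projection, head, row within head), and the sketch shows how to reorder from one layout to the other. This assumes plain MHA (equal-sized q, k and v, no GQA) with head_dim equal to kv_channels. It also assumes the interleaved layout stores, for each head, a q block, then a k block, then a v block. The exact in-memory layout is not spelled out above, and the helper names are hypothetical.

```python
from typing import List, Tuple

Row = Tuple[str, int, int]  # (projection, head index, row within the head)


def qkv_row_labels(num_heads: int, head_dim: int, interleaved: bool = True) -> List[Row]:
    """Label the 3 * num_heads * head_dim rows of a fused QKV weight."""
    if interleaved:
        # Assumed interleaved layout: per head, a q block, then k, then v.
        return [(p, h, d) for h in range(num_heads) for p in "qkv" for d in range(head_dim)]
    # qkv_weight_interleaved=False: all query rows, then all key rows, then all value rows.
    return [(p, h, d) for p in "qkv" for h in range(num_heads) for d in range(head_dim)]


def interleaved_to_concatenated(rows: List, num_heads: int, head_dim: int) -> List:
    """Reorder weight rows from the per-head interleaved layout to [Q; K; V]."""
    out = []
    for p in range(3):  # 0 = q, 1 = k, 2 = v
        for h in range(num_heads):
            start = (h * 3 + p) * head_dim
            out.extend(rows[start:start + head_dim])
    return out
```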
- Parallelism parameters
set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 are used as Row Parallel as described here.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
- Optimization parameters
fuse_wgrad_accumulation (bool, default = False) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – controls the dtype used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
return_bias (bool, default = False) – when set to True, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation \(y = xA^T\). This is useful when the bias addition can be fused to subsequent operations.
fuse_qkv_params (bool, default = False) – if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.
- forward(hidden_states: torch.Tensor, attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[torch.Tensor] = None, fast_zero_fill: bool = True)
Forward propagation for MultiheadAttention layer.
Note
Argument attention_mask is only used when attn_mask_type includes “padding” or “arbitrary”.
- Parameters
hidden_states (torch.Tensor) – Input tensor.
attention_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], default = None) – Boolean tensor(s) used to mask out attention softmax input. It should be ‘None’ for ‘causal’ and ‘no_mask’ types. For ‘padding’ masks, it should be a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For the ‘arbitrary’ mask type, it should be in a shape that is broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'}, default = None) – type of attention mask passed into softmax operation.
encoder_output (Optional[torch.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type=“decoder”.
is_first_microbatch ({True, False, None}, default = None) –
During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
checkpoint_core_attention (bool, default = False) – If True, forward activations for core attention are recomputed during the backward pass, saving the memory that would otherwise hold those activations until backprop.
rotary_pos_emb (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None) – Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied.
core_attention_bias_type (str, default = no_bias) – Bias type, {no_bias, pre_scale_bias, post_scale_bias, alibi}
core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv]. It should be ‘None’ for ‘no_bias’ and ‘alibi’ bias types.
fast_zero_fill (bool, default = True) – Whether to set output tensors to 0 or not before use.
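The attention_mask shape rules above can be made concrete with a small sketch. It uses nested Python lists instead of torch.Tensor objects so that it runs anywhere; padding_mask is a hypothetical helper, not part of Transformer Engine, and it assumes the convention that True marks a padded position to be masked out. In practice the result would be a boolean tensor on the GPU.

```python
from typing import List

Mask = List[List[List[List[bool]]]]  # nested-list stand-in for a 4-D bool tensor


def padding_mask(seqlens: List[int], max_seqlen: int) -> Mask:
    """Build a [batch_size, 1, 1, max_seqlen] padding mask.

    Hypothetical helper. Assumes True marks a padded position that should be
    masked out of the attention softmax input.
    """
    return [[[[pos >= n for pos in range(max_seqlen)]]] for n in seqlens]


# Self-attention with 'padding': one mask of shape [batch_size, 1, 1, seqlen_q].
mask_q = padding_mask([3, 5], max_seqlen=5)
print(mask_q[0][0][0])  # [False, False, False, True, True]

# Cross-attention with 'padding': a tuple of (query mask, key/value mask).
mask_kv = padding_mask([4, 2], max_seqlen=6)
cross_attention_mask = (mask_q, mask_kv)
```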
- set_context_parallel_group(cp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream)¶
Set the context parallel attributes for the given module before executing the forward pass.
- Parameters
cp_group (ProcessGroup) – context parallel process group.
cp_global_ranks (List[int]) – list of global ranks in the context group.
cp_stream (torch.cuda.Stream) – CUDA stream for context parallel execution.
- set_tensor_parallel_group(tp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None])¶
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)¶
TransformerLayer is made up of an attention block and a feedforward network (MLP). This standard layer is based on the paper “Attention Is All You Need”.
Note
Argument attention_mask in the forward call is only used when self_attn_mask_type includes “padding” or “arbitrary”.
- Parameters
hidden_size (int) – size of each input sample.
ffn_hidden_size (int) – intermediate size to which input samples are projected.
num_attention_heads (int) – number of attention heads in the transformer layer.
num_gqa_groups (int, default = None) –
number of GQA groups in the transformer layer. Grouped Query Attention is described in this paper. This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention (MQA), while GQA-H is equivalent to MHA, i.e. num_gqa_groups = num_attention_heads.
layernorm_epsilon (float, default = 1e-5) – a value added to the denominator of layer normalization for numerical stability.
hidden_dropout (float, default = 0.1) – dropout probability for the dropout op after FC2 layer.
attention_dropout (float, default = 0.1) – dropout probability for the dropout op during multi-head attention.
init_method (Callable, default = None) – used for initializing weights of QKV and FC1 weights in the following way: init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
output_layer_init_method (Callable, default = None) – used for initializing weights of PROJ and FC2 in the following way: output_layer_init_method(weight). When set to None, defaults to torch.nn.init.normal_(mean=0.0, std=0.023).
apply_residual_connection_post_layernorm (bool, default = False) – if set to True, residual connections are taken from the output of layer norm (by default they are taken from the input of layer norm).
layer_number (int, default = None) – layer number of the current TransformerLayer when multiple such modules are concatenated to form a transformer block.
output_layernorm (bool, default = False) – if set to True, layer normalization is applied on the output side, after the final dropout-add. The default behavior is to apply layer normalization on the input side, before the QKV transformation.
parallel_attention_mlp (bool, default = False) – if set to True, self-attention and feedforward network are computed based on the same input (in parallel) instead of sequentially. Both blocks have an independent normalization. This architecture is used in Falcon models.
layer_type ({‘encoder’, ‘decoder’}, default = encoder) – if set to decoder, an additional cross-attn block is added after self-attn. This can be used for structures like T5 Transformer in conjunction with the encoder option.
kv_channels (int, default = None) – number of key-value channels. Defaults to hidden_size / num_attention_heads if None.
self_attn_mask_type ({'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'}, default = causal) – type of attention mask passed into the softmax operation. Overridden by self_attn_mask_type in the forward method. The forward argument is useful for dynamically changing mask types, e.g. a different mask for training and inference. The init argument is useful for cases involving compilation/tracing, e.g. ONNX export.
zero_centered_gamma (bool, default = 'False') – if set to ‘True’, the gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta\]
normalization ({ 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm') – type of normalization applied.
qkv_weight_interleaved (bool, default = True) – if set to False, the QKV weight is interpreted as a concatenation of query, key, and value weights along the 0th dimension. The default interpretation is that the individual q, k, and v weights for each attention head are interleaved. This parameter is set to False when using fuse_qkv_params=False.
bias (bool, default = True) – if set to False, the transformer layer will not learn any additive biases.
activation (str, default = 'gelu') – Type of activation used in MLP block. Options are: ‘gelu’, ‘relu’, ‘reglu’, ‘geglu’ and ‘swiglu’.
device (Union[torch.device, str], default = "cuda") – The device on which the parameters of the model will be allocated. It is the user’s responsibility to ensure all parameters are moved to the GPU before running the forward pass.
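To illustrate num_gqa_groups, the sketch below maps each query head to the key/value head it shares. It is plain Python, the function name is hypothetical, and it assumes contiguous, equally sized groups; it is not taken from the Transformer Engine source.

```python
from typing import List


def kv_head_for_query_head(num_attention_heads: int, num_gqa_groups: int) -> List[int]:
    """Map each query head to the key/value head it attends with.

    Sketch only: assumes query heads are split into contiguous, equally sized
    groups and that each group shares one key/value head.
    """
    if num_attention_heads % num_gqa_groups != 0:
        raise ValueError("num_attention_heads must be divisible by num_gqa_groups")
    heads_per_group = num_attention_heads // num_gqa_groups
    return [q // heads_per_group for q in range(num_attention_heads)]


print(kv_head_for_query_head(8, 2))  # GQA-2: [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_head_for_query_head(8, 1))  # GQA-1 (MQA): every query head uses KV head 0
print(kv_head_for_query_head(8, 8))  # GQA-H (MHA): [0, 1, 2, 3, 4, 5, 6, 7]

# With the default kv_channels = hidden_size / num_attention_heads, the key and
# value projections each shrink from num_attention_heads * kv_channels outputs
# to num_gqa_groups * kv_channels outputs.
hidden_size, num_attention_heads, num_gqa_groups = 1024, 8, 2
kv_channels = hidden_size // num_attention_heads
print(num_gqa_groups * kv_channels)  # 256
```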
- Parallelism parameters
set_parallel_mode (bool, default = False) – if set to True, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 are used as Row Parallel as described here.
sequence_parallel (bool, default = False) – if set to True, uses sequence parallelism.
tp_group (ProcessGroup, default = None) – tensor parallel process group.
tp_size (int, default = 1) – used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the set_tensor_parallel_group(tp_group) method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives.
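The Column/Row split selected by set_parallel_mode can be checked on a toy example. The sketch below simulates tp_size ranks in one Python process, uses w = A^T (so y = x @ w), and replaces the all-gather and all-reduce collectives a real tensor-parallel run would use with plain concatenation and summation. It is an illustration of the Megatron-style scheme the parameter refers to, under those assumptions, not Transformer Engine code.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Multiply an [m, k] matrix by a [k, n] matrix."""
    return [[sum(x[i][t] * w[t][j] for t in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]


x = [[1.0, 2.0, 3.0, 4.0]]                             # 1 token, 4 input features
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # w = A^T, shape [4, 2]
full = matmul(x, w)                                    # [[12.0, 1.0]]
tp_size = 2                                            # simulated ranks

# Column parallel: rank r owns a contiguous slice of the output columns.
# Concatenating the per-rank outputs stands in for an all-gather.
n_out = len(w[0]) // tp_size
col_outputs = [matmul(x, [row[r * n_out:(r + 1) * n_out] for row in w]) for r in range(tp_size)]
col_result = [sum((out[i] for out in col_outputs), []) for i in range(len(x))]

# Row parallel: rank r owns a slice of the input features (rows of w) and the
# matching slice of x. Adding the partial sums stands in for an all-reduce.
n_in = len(w) // tp_size
row_outputs = [
    matmul([xi[r * n_in:(r + 1) * n_in] for xi in x], w[r * n_in:(r + 1) * n_in])
    for r in range(tp_size)
]
row_result = [[sum(out[i][j] for out in row_outputs) for j in range(len(w[0]))]
              for i in range(len(x))]

print(full, col_result, row_result)  # all three equal [[12.0, 1.0]]
```

Pairing a column-parallel layer with a row-parallel one (QKV then PROJ, FC1 then FC2) keeps the intermediate activation sharded between them, so in this scheme a single reduction after the second layer is enough for the pair.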
- Optimization parameters
fuse_wgrad_accumulation (bool, default = ‘False’) – if set to True, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional main_grad attribute (used instead of the regular grad) which is a pre-allocated buffer of the correct size to accumulate gradients in.
params_dtype (torch.dtype, default = torch.get_default_dtype()) – it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory.
seq_length (int) – sequence length of input samples. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for the forward propagation and activation recompute phases.
micro_batch_size (int) – batch size per training step. Needed for JIT warmup, a technique where JIT-fused functions are warmed up before training to ensure the same kernels are used for the forward propagation and activation recompute phases.
drop_path_rate (float, default = 0.0) – when > 0.0, applies stochastic depth per sample in the main path of the residual block.
fuse_qkv_params (bool, default = ‘False’) – if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.
- forward(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[transformer_engine.pytorch.attention.InferenceParams] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, core_attention_bias_type: str = 'no_bias', core_attention_bias: Optional[torch.Tensor] = None, fast_zero_fill: bool = True)¶
Transformer Layer: attention block and a feedforward network (MLP)
Note
Argument attention_mask is only used when self_attn_mask_type includes “padding” or “arbitrary”.
- Parameters
hidden_states (torch.Tensor) – Input tensor.
attention_mask (Optional[torch.Tensor], default = None) – Boolean tensor used to mask out self-attention softmax input. It should be in [batch_size, 1, 1, seqlen_q] for ‘padding’ mask, and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for ‘arbitrary’. It should be ‘None’ for ‘causal’ and ‘no_mask’.
self_attn_mask_type ({'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = causal) – type of attention mask passed into the softmax operation.
encoder_output (Optional[torch.Tensor], default = None) – Output of the encoder block to be fed into the decoder block if using layer_type="decoder".
enc_dec_attn_mask (Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], default = None) – Boolean tensors used to mask out inter-attention softmax input if using layer_type="decoder". It should be a tuple of two masks of shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for the ‘padding’ mask. It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for the ‘arbitrary’ mask. It should be ‘None’ for ‘causal’ and ‘no_mask’.
is_first_microbatch ({True, False, None}, default = None) –
During training using either gradient accumulation or pipeline parallelism, a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations:
during FP8 training, it allows caching of the FP8 versions of the weights
it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced)
checkpoint_core_attention (bool, default = False) – If True, forward activations for core attention are recomputed during the backward pass, saving the memory that would otherwise hold those activations until backprop.
rotary_pos_emb (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None) – Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied.
core_attention_bias_type (str, default = no_bias) – Bias type, {no_bias, pre_scale_bias, post_scale_bias, alibi}
core_attention_bias (Optional[torch.Tensor], default = None) – Bias tensor for Q * K.T.
fast_zero_fill (bool, default = True) – Whether to set output tensors to 0 or not before use.
inference_params (InferenceParams, default = None) – Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference.
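The is_first_microbatch flag follows a simple schedule: True for the first microbatch of each minibatch and False for the rest. The sketch below generates that schedule in plain Python; microbatch_schedule is a hypothetical helper, and the commented layer(...) call only indicates where the flag would be passed.

```python
from typing import Iterator, Sequence, Tuple


def microbatch_schedule(minibatch: Sequence, micro_batch_size: int) -> Iterator[Tuple[Sequence, bool]]:
    """Yield (microbatch, is_first_microbatch) pairs for one minibatch."""
    for start in range(0, len(minibatch), micro_batch_size):
        yield minibatch[start:start + micro_batch_size], start == 0


minibatch = list(range(10))  # stand-in for 10 samples
schedule = list(microbatch_schedule(minibatch, micro_batch_size=4))
for microbatch, first in schedule:
    # In a real loop this would be something like:
    #   out = layer(microbatch_tensor, is_first_microbatch=first)
    #   loss(out).backward()
    print(len(microbatch), first)
# The optimizer step happens only after all microbatches of the minibatch,
# so the weights stay fixed in between (prints: 4 True, 4 False, 2 False).
```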
- set_context_parallel_group(cp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None], cp_global_ranks: List[int], cp_stream: torch.cuda.Stream)¶
Set the context parallel attributes for the given module before executing the forward pass.
- Parameters
cp_group (ProcessGroup) – context parallel process group.
cp_global_ranks (List[int]) – list of global ranks in the context group.
cp_stream (torch.cuda.Stream) – CUDA stream for context parallel execution.
- set_tensor_parallel_group(tp_group: Union[transformer_engine.pytorch.constants.dist_group_type, None])¶
Set the tensor parallel group for the given module before executing the forward pass.
- Parameters
tp_group (ProcessGroup, default = None) – tensor parallel process group.
- class transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)¶
Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference.
- Parameters
max_batch_size (int) – maximum batch size during inference.
max_sequence_length (int) – maximum sequence length during inference.
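The role of max_batch_size and max_sequence_length can be sketched with a toy key/value cache. ToyKVCache is entirely hypothetical and written in plain Python; it only illustrates why both maxima are fixed at construction time (buffers are allocated once and filled step by step) and does not reflect the internal layout of InferenceParams.

```python
from typing import Any, List, Tuple


class ToyKVCache:
    """Hypothetical, pure-Python illustration of a preallocated context cache.

    Buffers sized max_batch_size x max_sequence_length are allocated once up
    front; each decoding step writes the new key/value into the next slot, so
    earlier context never has to be recomputed.
    """

    def __init__(self, max_batch_size: int, max_sequence_length: int) -> None:
        self.max_sequence_length = max_sequence_length
        self.keys: List[List[Any]] = [[None] * max_sequence_length for _ in range(max_batch_size)]
        self.values: List[List[Any]] = [[None] * max_sequence_length for _ in range(max_batch_size)]
        self.offset = 0  # number of positions filled so far

    def append(self, step_keys: List[Any], step_values: List[Any]) -> Tuple[List[List[Any]], List[List[Any]]]:
        """Store one (key, value) per sequence; return the cached context so far."""
        if self.offset >= self.max_sequence_length:
            raise IndexError("max_sequence_length exceeded")
        for b, (k, v) in enumerate(zip(step_keys, step_values)):
            self.keys[b][self.offset] = k  # IndexError if b >= max_batch_size
            self.values[b][self.offset] = v
        self.offset += 1
        n = len(step_keys)
        return ([row[:self.offset] for row in self.keys[:n]],
                [row[:self.offset] for row in self.values[:n]])


cache = ToyKVCache(max_batch_size=2, max_sequence_length=3)
cache.append(["k0a", "k0b"], ["v0a", "v0b"])
keys, values = cache.append(["k1a", "k1b"], ["v1a", "v1b"])
print(keys)  # [['k0a', 'k1a'], ['k0b', 'k1b']]
```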
- class transformer_engine.pytorch.CudaRNGStatesTracker¶
For model parallelism, multiple RNG states need to exist simultaneously in order to execute operations in or out of the model parallel region. This class keeps track of the various RNG states and provides utility methods to maintain them and execute parts of the model under a given RNG setting. Using the add method, a CUDA RNG state is initialized based on the input seed and is assigned to name. Later, by forking the RNG state, we can perform operations and return to our starting CUDA state.
- add(name: str, seed: int)¶
Adds a new RNG state.
- name: str
string identifier for the RNG state.
- seed: int
PyTorch seed for the RNG state.
- fork(name: str = 'model-parallel-rng')¶
Fork the CUDA RNG state, perform operations, and exit with the original state.
- name: str
string identifier for the RNG state.
- get_states()¶
Get RNG states. Copy the dictionary so we have direct pointers to the states, not just a pointer to the dictionary.
- reset()¶
Set to the initial state (no tracker).
- set_states(states: Dict[str, torch.Tensor])¶
Set the rng states. For efficiency purposes, we do not check the size of seed for compatibility.
- states: Dict[str, torch.Tensor]
A mapping from string names to RNG states.
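The add/fork pattern described above can be mimicked with Python's random module. ToyRNGStatesTracker is a hypothetical, CPU-only analogue written for illustration; it swaps named generator states in and out the way the class description says CUDA RNG states are handled, but it is not the Transformer Engine implementation.

```python
import contextlib
import random
from typing import Dict, Iterator


class ToyRNGStatesTracker:
    """CPU-only analogue of the named-RNG-state pattern (hypothetical)."""

    def __init__(self) -> None:
        self.states: Dict[str, object] = {}

    def add(self, name: str, seed: int) -> None:
        """Create a named state from `seed` without disturbing the current stream."""
        if name in self.states:
            raise ValueError(f"RNG state {name!r} already exists")
        original = random.getstate()
        random.seed(seed)
        self.states[name] = random.getstate()
        random.setstate(original)

    @contextlib.contextmanager
    def fork(self, name: str = "model-parallel-rng") -> Iterator[None]:
        """Swap in the named state, run the block, then restore the original one."""
        original = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            self.states[name] = random.getstate()  # remember how far the named stream advanced
            random.setstate(original)              # return to the starting state

    def get_states(self) -> Dict[str, object]:
        return dict(self.states)  # shallow copy of the name -> state mapping

    def set_states(self, states: Dict[str, object]) -> None:
        self.states = dict(states)

    def reset(self) -> None:
        self.states = {}


random.seed(0)
tracker = ToyRNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)

before = random.getstate()
with tracker.fork():               # e.g. dropout inside the model-parallel region
    first_draw = random.random()
with tracker.fork():
    second_draw = random.random()  # continues the named stream
restored = random.getstate() == before
print(restored)  # True
```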
- transformer_engine.pytorch.fp8_autocast(enabled: bool = True, calibrating: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, fp8_group: Optional[transformer_engine.pytorch.constants.dist_group_type] = None)¶
Context manager for FP8 usage.
with fp8_autocast(enabled=True):
    out = model(inp)
Note
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding the sequence length to a multiple of 16.
Note
When fp8_recipe.reduce_amax==True, no module may be invoked more than once inside a single fp8_autocast region. This is unsupported because the amax reduction is handled during the exit of the fp8_autocast context; calling the same module more than once inside an fp8_autocast region overwrites the amax tensors before the reduction can occur.
- Parameters
enabled (bool, default = True) – whether or not to enable fp8
calibrating (bool, default = False) – calibration mode allows collecting statistics such as amax and scale data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision.
fp8_recipe (recipe.DelayedScaling, default = None) – recipe used for FP8 training.
fp8_group (torch._C._distributed_c10d.ProcessGroup, default = None) – distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step.
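The divisibility rule in the first note reduces to rounding up. The sketch below computes the padded sequence length; pad_to_multiple and fp8_shape_ok are hypothetical helpers. It assumes Linear inputs are flattened to [sequence_length * batch_size, hidden_size] and that hidden_size is already a multiple of 16, in which case padding the sequence length to a multiple of 16 also makes the flattened first dimension a multiple of 16.

```python
def pad_to_multiple(n: int, multiple: int = 16) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return -(-n // multiple) * multiple


def fp8_shape_ok(rows: int, cols: int, multiple: int = 16) -> bool:
    """Check the 'both dimensions divisible by 16' rule from the note."""
    return rows % multiple == 0 and cols % multiple == 0


seq_len, batch_size, hidden_size = 100, 3, 768
padded_seq_len = pad_to_multiple(seq_len)                      # 112, i.e. 12 padding positions
print(fp8_shape_ok(seq_len * batch_size, hidden_size))         # False: 300 % 16 == 12
print(fp8_shape_ok(padded_seq_len * batch_size, hidden_size))  # True: 336 = 21 * 16
```

The padded positions would then typically be excluded from attention with a ‘padding’ mask.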
- transformer_engine.pytorch.fp8_model_init(enabled: bool = True)¶
Context manager for FP8 initialization of parameters.
Example usage:
with fp8_model_init(enabled=True):
    model = transformer_engine.pytorch.Linear(768, 768)
- Parameters
enabled (bool, default = True) –
when enabled, Transformer Engine modules created inside this fp8_model_init region will hold only FP8 copies of their parameters, as opposed to the default behavior where both higher precision and FP8 copies are present. Setting this option to True may result in lower memory consumption and is especially useful for scenarios like:
full model training using optimizer with master weights, where the high precision copies of weights are already present in the optimizer.
inference, where only the FP8 copies of the parameters are used.
LoRA-like fine-tuning, where the main parameters of the model do not change.
This functionality is EXPERIMENTAL.
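A rough worked example of the memory saving, using the 768 x 768 Linear from the example above. The arithmetic assumes parameters are allocated in FP32 (the torch default params_dtype) and counts only the weight matrix, ignoring the bias, any transposed FP8 copies, scaling factors, and amax history, so real numbers will differ.

```python
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2, "fp8": 1}

in_features = out_features = 768
num_weight_elements = in_features * out_features  # 589,824


def weight_bytes(high_precision: str, fp8_only: bool) -> int:
    """Bytes held for one weight matrix under the two storage modes."""
    if fp8_only:  # module created inside fp8_model_init(enabled=True)
        return num_weight_elements * BYTES_PER_ELEMENT["fp8"]
    # default: a high-precision copy plus an FP8 copy
    return num_weight_elements * (BYTES_PER_ELEMENT[high_precision] + BYTES_PER_ELEMENT["fp8"])


print(weight_bytes("fp32", fp8_only=False))  # 2949120 (about 2.8 MiB)
print(weight_bytes("fp32", fp8_only=True))   # 589824 (about 0.56 MiB)
```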
- transformer_engine.pytorch.checkpoint(function: Callable, distribute_saved_activations: bool, get_cuda_rng_tracker: Callable, tp_group: transformer_engine.pytorch.constants.dist_group_type, *args: Tuple[torch.Tensor, Ellipsis], **kwargs: Dict[str, Any])¶
Checkpoint a part of the model by trading compute for memory. This function is based on torch.utils.checkpoint.checkpoint.
Warning
It is the user’s responsibility to ensure identical behavior when calling function from the forward and backward pass. If different output is produced (e.g. due to global state), then the checkpointed version won’t be numerically equivalent.
Warning
The tuple args must contain only tensors (or None) in order to comply with PyTorch’s save_for_backward method. function must be callable to produce valid outputs with the inputs args and kwargs.
- Parameters
function (Callable) – PyTorch module used to run the forward and backward passes using the specified args and kwargs.
distribute_saved_activations (bool) – if set to True, the first tensor argument is distributed across the specified tensor parallel group (tp_group) before saving it for the backward pass.
get_cuda_rng_tracker (Callable) – Python callable which returns an instance of CudaRNGStatesTracker().
tp_group (ProcessGroup, default = None) – tensor parallel process group.
args (tuple) – tuple of torch tensors for inputs to function.
kwargs (dict) – dictionary of string keys for keyword arguments to function.
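The first warning matters most for layers with randomness such as dropout: the recomputation during the backward pass must draw the same random numbers as the original forward pass, which is why checkpoint takes a get_cuda_rng_tracker callable. The sketch below shows the idea in plain Python with random.Random standing in for the CUDA generator; the helper names are hypothetical and this is not how Transformer Engine implements checkpointing.

```python
import random
from typing import List, Tuple


def noisy_layer(x: List[float], rng: random.Random) -> List[float]:
    """Toy layer with dropout-like randomness: zero each element with p = 0.5."""
    return [xi if rng.random() >= 0.5 else 0.0 for xi in x]


def forward_for_checkpoint(x: List[float], rng: random.Random) -> Tuple[List[float], tuple]:
    """Run the layer, keeping only its inputs and the RNG state it started from."""
    saved = (list(x), rng.getstate())
    return noisy_layer(x, rng), saved


def recompute(saved: tuple) -> List[float]:
    """Replay the layer during 'backward' from the saved inputs and RNG state."""
    x, state = saved
    rng = random.Random()
    rng.setstate(state)
    return noisy_layer(x, rng)


rng = random.Random(42)
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
out, saved = forward_for_checkpoint(x, rng)
print(recompute(saved) == out)  # True: same inputs + same RNG state -> same output
# Replaying with the already-advanced rng instead of the saved state would
# generally draw different dropout decisions and break numerical equivalence.
```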
- transformer_engine.pytorch.onnx_export(enabled: bool = False)¶
Context manager for exporting to ONNX.
with onnx_export(enabled=True):
    torch.onnx.export(model, inp, "model.onnx")  # inp: example input; "model.onnx": output path
- Parameters
enabled (bool, default = False) – whether or not to enable export