Transformer Engine
1.2.0dev-e10997b
Version select:
Current release
Older releases
Home
Getting Started
Installation
Prerequisites
Transformer Engine in NGC Containers
pip - from GitHub
Additional Prerequisites
Installation (stable release)
Installation (development build)
Installation (from source)
Getting Started
Overview
Let’s build a Transformer layer!
Meet Transformer Engine
Fused TE Modules
Enabling FP8
Python API documentation
Common API
Format
DelayedScaling
Framework-specific API
PyTorch
Linear
forward
set_tensor_parallel_group
LayerNorm
RMSNorm
LayerNormLinear
forward
set_tensor_parallel_group
LayerNormMLP
forward
set_tensor_parallel_group
DotProductAttention
forward
set_context_parallel_group
MultiheadAttention
forward
set_context_parallel_group
set_tensor_parallel_group
TransformerLayer
forward
set_context_parallel_group
set_tensor_parallel_group
InferenceParams
CudaRNGStatesTracker
add
fork
get_states
reset
set_states
fp8_autocast
fp8_model_init
checkpoint
onnx_export
JAX
MajorShardingType
ShardingType
TransformerLayerType
ShardingResource
fp8_autocast
update_collections
update_fp8_metas
LayerNorm
__call__
DenseGeneral
__call__
LayerNormDenseGeneral
__call__
LayerNormMLP
__call__
RelativePositionBiases
__call__
MultiHeadAttention
__call__
TransformerLayer
__call__
extend_logical_axis_rules
paddle
Linear
forward
LayerNorm
LayerNormLinear
forward
LayerNormMLP
forward
FusedScaleMaskSoftmax
forward
DotProductAttention
forward
MultiHeadAttention
forward
TransformerLayer
forward
fp8_autocast
recompute
Examples and Tutorials
Using FP8 with Transformer Engine
Introduction to FP8
Structure
Mixed precision training - a quick introduction
Mixed precision training with FP8
Using FP8 with Transformer Engine
FP8 recipe
FP8 autocasting
Handling backward pass
Precision
Performance Optimizations
Multi-GPU training
Gradient accumulation fusion
FP8 weight caching
Advanced
C/C++ API
activation.h
void nvte_gelu
void nvte_dgelu
void nvte_geglu
void nvte_dgeglu
void nvte_relu
void nvte_drelu
void nvte_swiglu
void nvte_dswiglu
void nvte_reglu
void nvte_dreglu
cast.h
void nvte_fp8_quantize
void nvte_fp8_dequantize
gemm.h
void nvte_cublas_gemm
void nvte_cublas_atomic_gemm
fused_attn.h
enum NVTE_QKV_Layout
enumerator NVTE_SB3HD
enumerator NVTE_SBH3D
enumerator NVTE_SBHD_SB2HD
enumerator NVTE_SBHD_SBH2D
enumerator NVTE_SBHD_SBHD_SBHD
enumerator NVTE_BS3HD
enumerator NVTE_BSH3D
enumerator NVTE_BSHD_BS2HD
enumerator NVTE_BSHD_BSH2D
enumerator NVTE_BSHD_BSHD_BSHD
enumerator NVTE_T3HD
enumerator NVTE_TH3D
enumerator NVTE_THD_T2HD
enumerator NVTE_THD_TH2D
enumerator NVTE_THD_THD_THD
enum NVTE_QKV_Layout_Group
enumerator NVTE_3HD
enumerator NVTE_H3D
enumerator NVTE_HD_2HD
enumerator NVTE_HD_H2D
enumerator NVTE_HD_HD_HD
enum NVTE_QKV_Format
enumerator NVTE_SBHD
enumerator NVTE_BSHD
enumerator NVTE_THD
enum NVTE_Bias_Type
enumerator NVTE_NO_BIAS
enumerator NVTE_PRE_SCALE_BIAS
enumerator NVTE_POST_SCALE_BIAS
enumerator NVTE_ALIBI
enum NVTE_Mask_Type
enumerator NVTE_NO_MASK
enumerator NVTE_PADDING_MASK
enumerator NVTE_CAUSAL_MASK
enumerator NVTE_PADDING_CAUSAL_MASK
enum NVTE_Fused_Attn_Backend
enumerator NVTE_No_Backend
enumerator NVTE_F16_max512_seqlen
enumerator NVTE_F16_arbitrary_seqlen
enumerator NVTE_FP8
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group
NVTE_QKV_Format nvte_get_qkv_format
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend
void nvte_fused_attn_fwd_qkvpacked
void nvte_fused_attn_bwd_qkvpacked
void nvte_fused_attn_fwd_kvpacked
void nvte_fused_attn_bwd_kvpacked
void nvte_fused_attn_fwd
void nvte_fused_attn_bwd
layer_norm.h
void nvte_layernorm_fwd
void nvte_layernorm1p_fwd
void nvte_layernorm_bwd
void nvte_layernorm1p_bwd
rmsnorm.h
void nvte_rmsnorm_fwd
void nvte_rmsnorm_bwd
softmax.h
void nvte_scaled_softmax_forward
void nvte_scaled_softmax_backward
void nvte_scaled_masked_softmax_forward
void nvte_scaled_masked_softmax_backward
void nvte_scaled_upper_triang_masked_softmax_forward
void nvte_scaled_upper_triang_masked_softmax_backward
transformer_engine.h
typedef void *NVTETensor
enum NVTEDType
enumerator kNVTEByte
enumerator kNVTEInt32
enumerator kNVTEInt64
enumerator kNVTEFloat32
enumerator kNVTEFloat16
enumerator kNVTEBFloat16
enumerator kNVTEFloat8E4M3
enumerator kNVTEFloat8E5M2
enumerator kNVTENumTypes
NVTETensor nvte_create_tensor
void nvte_destroy_tensor
NVTEDType nvte_tensor_type
NVTEShape nvte_tensor_shape
void *nvte_tensor_data
float *nvte_tensor_amax
float *nvte_tensor_scale
float *nvte_tensor_scale_inv
void nvte_tensor_pack_create
void nvte_tensor_pack_destroy
struct NVTEShape
const size_t *data
size_t ndim
struct NVTETensorPack
NVTETensor tensors[MAX_SIZE]
size_t size = 0
static const int MAX_SIZE = 10
namespace transformer_engine
enum class DType
struct TensorWrapper
transpose.h
void nvte_cast_transpose
void nvte_transpose
void nvte_cast_transpose_dbias
void nvte_fp8_transpose_dbias
void nvte_cast_transpose_dbias_dgelu
void nvte_multi_cast_transpose
void nvte_dgeglu_cast_transpose
Transformer Engine
»
Index
Index
_ | A | C | D | E | F | G | I | L | M | N | O | R | S | T | U
_
__call__() (transformer_engine.jax.flax.DenseGeneral method)
(transformer_engine.jax.flax.LayerNorm method)
(transformer_engine.jax.flax.LayerNormDenseGeneral method)
(transformer_engine.jax.flax.LayerNormMLP method)
(transformer_engine.jax.flax.MultiHeadAttention method)
(transformer_engine.jax.flax.RelativePositionBiases method)
(transformer_engine.jax.flax.TransformerLayer method)
A
add() (transformer_engine.pytorch.CudaRNGStatesTracker method)
C
checkpoint() (in module transformer_engine.pytorch)
CudaRNGStatesTracker (class in transformer_engine.pytorch)
D
DelayedScaling (class in transformer_engine.common.recipe)
DenseGeneral (class in transformer_engine.jax.flax)
DotProductAttention (class in transformer_engine.paddle)
(class in transformer_engine.pytorch)
E
extend_logical_axis_rules() (in module transformer_engine.jax.flax)
F
fork() (transformer_engine.pytorch.CudaRNGStatesTracker method)
Format (class in transformer_engine.common.recipe)
forward() (transformer_engine.paddle.DotProductAttention method)
(transformer_engine.paddle.FusedScaleMaskSoftmax method)
(transformer_engine.paddle.LayerNormLinear method)
(transformer_engine.paddle.LayerNormMLP method)
(transformer_engine.paddle.Linear method)
(transformer_engine.paddle.MultiHeadAttention method)
(transformer_engine.paddle.TransformerLayer method)
(transformer_engine.pytorch.DotProductAttention method)
(transformer_engine.pytorch.LayerNormLinear method)
(transformer_engine.pytorch.LayerNormMLP method)
(transformer_engine.pytorch.Linear method)
(transformer_engine.pytorch.MultiheadAttention method)
(transformer_engine.pytorch.TransformerLayer method)
fp8_autocast() (in module transformer_engine.jax)
(in module transformer_engine.paddle)
(in module transformer_engine.pytorch)
fp8_model_init() (in module transformer_engine.pytorch)
FusedScaleMaskSoftmax (class in transformer_engine.paddle)
G
get_states() (transformer_engine.pytorch.CudaRNGStatesTracker method)
I
InferenceParams (class in transformer_engine.pytorch)
L
LayerNorm (class in transformer_engine.jax.flax)
(class in transformer_engine.paddle)
(class in transformer_engine.pytorch)
LayerNormDenseGeneral (class in transformer_engine.jax.flax)
LayerNormLinear (class in transformer_engine.paddle)
(class in transformer_engine.pytorch)
LayerNormMLP (class in transformer_engine.jax.flax)
(class in transformer_engine.paddle)
(class in transformer_engine.pytorch)
Linear (class in transformer_engine.paddle)
(class in transformer_engine.pytorch)
M
MajorShardingType (class in transformer_engine.jax)
MultiHeadAttention (class in transformer_engine.jax.flax)
(class in transformer_engine.paddle)
MultiheadAttention (class in transformer_engine.pytorch)
N
NVTE_Bias_Type (C++ enum)
NVTE_Bias_Type::NVTE_ALIBI (C++ enumerator)
NVTE_Bias_Type::NVTE_NO_BIAS (C++ enumerator)
NVTE_Bias_Type::NVTE_POST_SCALE_BIAS (C++ enumerator)
NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS (C++ enumerator)
nvte_cast_transpose (C++ function)
nvte_cast_transpose_dbias (C++ function)
nvte_cast_transpose_dbias_dgelu (C++ function)
nvte_create_tensor (C++ function)
nvte_cublas_atomic_gemm (C++ function)
nvte_cublas_gemm (C++ function)
nvte_destroy_tensor (C++ function)
nvte_dgeglu (C++ function)
nvte_dgeglu_cast_transpose (C++ function)
nvte_dgelu (C++ function)
nvte_dreglu (C++ function)
nvte_drelu (C++ function)
nvte_dswiglu (C++ function)
nvte_fp8_dequantize (C++ function)
nvte_fp8_quantize (C++ function)
nvte_fp8_transpose_dbias (C++ function)
NVTE_Fused_Attn_Backend (C++ enum)
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen (C++ enumerator)
NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen (C++ enumerator)
NVTE_Fused_Attn_Backend::NVTE_FP8 (C++ enumerator)
NVTE_Fused_Attn_Backend::NVTE_No_Backend (C++ enumerator)
nvte_fused_attn_bwd (C++ function)
nvte_fused_attn_bwd_kvpacked (C++ function)
nvte_fused_attn_bwd_qkvpacked (C++ function)
nvte_fused_attn_fwd (C++ function)
nvte_fused_attn_fwd_kvpacked (C++ function)
nvte_fused_attn_fwd_qkvpacked (C++ function)
nvte_geglu (C++ function)
nvte_gelu (C++ function)
nvte_get_fused_attn_backend (C++ function)
nvte_get_qkv_format (C++ function)
nvte_get_qkv_layout_group (C++ function)
nvte_layernorm1p_bwd (C++ function)
nvte_layernorm1p_fwd (C++ function)
nvte_layernorm_bwd (C++ function)
nvte_layernorm_fwd (C++ function)
NVTE_Mask_Type (C++ enum)
NVTE_Mask_Type::NVTE_CAUSAL_MASK (C++ enumerator)
NVTE_Mask_Type::NVTE_NO_MASK (C++ enumerator)
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK (C++ enumerator)
NVTE_Mask_Type::NVTE_PADDING_MASK (C++ enumerator)
nvte_multi_cast_transpose (C++ function)
NVTE_QKV_Format (C++ enum)
NVTE_QKV_Format::NVTE_BSHD (C++ enumerator)
NVTE_QKV_Format::NVTE_SBHD (C++ enumerator)
NVTE_QKV_Format::NVTE_THD (C++ enumerator)
NVTE_QKV_Layout (C++ enum)
NVTE_QKV_Layout::NVTE_BS3HD (C++ enumerator)
NVTE_QKV_Layout::NVTE_BSH3D (C++ enumerator)
NVTE_QKV_Layout::NVTE_BSHD_BS2HD (C++ enumerator)
NVTE_QKV_Layout::NVTE_BSHD_BSH2D (C++ enumerator)
NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD (C++ enumerator)
NVTE_QKV_Layout::NVTE_SB3HD (C++ enumerator)
NVTE_QKV_Layout::NVTE_SBH3D (C++ enumerator)
NVTE_QKV_Layout::NVTE_SBHD_SB2HD (C++ enumerator)
NVTE_QKV_Layout::NVTE_SBHD_SBH2D (C++ enumerator)
NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD (C++ enumerator)
NVTE_QKV_Layout::NVTE_T3HD (C++ enumerator)
NVTE_QKV_Layout::NVTE_TH3D (C++ enumerator)
NVTE_QKV_Layout::NVTE_THD_T2HD (C++ enumerator)
NVTE_QKV_Layout::NVTE_THD_TH2D (C++ enumerator)
NVTE_QKV_Layout::NVTE_THD_THD_THD (C++ enumerator)
NVTE_QKV_Layout_Group (C++ enum)
NVTE_QKV_Layout_Group::NVTE_3HD (C++ enumerator)
NVTE_QKV_Layout_Group::NVTE_H3D (C++ enumerator)
NVTE_QKV_Layout_Group::NVTE_HD_2HD (C++ enumerator)
NVTE_QKV_Layout_Group::NVTE_HD_H2D (C++ enumerator)
NVTE_QKV_Layout_Group::NVTE_HD_HD_HD (C++ enumerator)
nvte_reglu (C++ function)
nvte_relu (C++ function)
nvte_rmsnorm_bwd (C++ function)
nvte_rmsnorm_fwd (C++ function)
nvte_scaled_masked_softmax_backward (C++ function)
nvte_scaled_masked_softmax_forward (C++ function)
nvte_scaled_softmax_backward (C++ function)
nvte_scaled_softmax_forward (C++ function)
nvte_scaled_upper_triang_masked_softmax_backward (C++ function)
nvte_scaled_upper_triang_masked_softmax_forward (C++ function)
nvte_swiglu (C++ function)
nvte_tensor_amax (C++ function)
nvte_tensor_data (C++ function)
nvte_tensor_pack_create (C++ function)
nvte_tensor_pack_destroy (C++ function)
nvte_tensor_scale (C++ function)
nvte_tensor_scale_inv (C++ function)
nvte_tensor_shape (C++ function)
nvte_tensor_type (C++ function)
nvte_transpose (C++ function)
NVTEDType (C++ enum)
NVTEDType::kNVTEBFloat16 (C++ enumerator)
NVTEDType::kNVTEByte (C++ enumerator)
NVTEDType::kNVTEFloat16 (C++ enumerator)
NVTEDType::kNVTEFloat32 (C++ enumerator)
NVTEDType::kNVTEFloat8E4M3 (C++ enumerator)
NVTEDType::kNVTEFloat8E5M2 (C++ enumerator)
NVTEDType::kNVTEInt32 (C++ enumerator)
NVTEDType::kNVTEInt64 (C++ enumerator)
NVTEDType::kNVTENumTypes (C++ enumerator)
NVTEShape (C++ struct)
NVTEShape::data (C++ member)
NVTEShape::ndim (C++ member)
NVTETensor (C++ type)
NVTETensorPack (C++ struct)
NVTETensorPack::MAX_SIZE (C++ member)
NVTETensorPack::size (C++ member)
NVTETensorPack::tensors (C++ member)
O
onnx_export() (in module transformer_engine.pytorch)
R
recompute() (in module transformer_engine.paddle)
RelativePositionBiases (class in transformer_engine.jax.flax)
reset() (transformer_engine.pytorch.CudaRNGStatesTracker method)
RMSNorm (class in transformer_engine.pytorch)
S
set_context_parallel_group() (transformer_engine.pytorch.DotProductAttention method)
(transformer_engine.pytorch.MultiheadAttention method)
(transformer_engine.pytorch.TransformerLayer method)
set_states() (transformer_engine.pytorch.CudaRNGStatesTracker method)
set_tensor_parallel_group() (transformer_engine.pytorch.LayerNormLinear method)
(transformer_engine.pytorch.LayerNormMLP method)
(transformer_engine.pytorch.Linear method)
(transformer_engine.pytorch.MultiheadAttention method)
(transformer_engine.pytorch.TransformerLayer method)
ShardingResource (class in transformer_engine.jax)
ShardingType (class in transformer_engine.jax)
T
transformer_engine (C++ type)
transformer_engine::DType (C++ enum)
transformer_engine::DType::kBFloat16 (C++ enumerator)
transformer_engine::DType::kByte (C++ enumerator)
transformer_engine::DType::kFloat16 (C++ enumerator)
transformer_engine::DType::kFloat32 (C++ enumerator)
transformer_engine::DType::kFloat8E4M3 (C++ enumerator)
transformer_engine::DType::kFloat8E5M2 (C++ enumerator)
transformer_engine::DType::kInt32 (C++ enumerator)
transformer_engine::DType::kInt64 (C++ enumerator)
transformer_engine::DType::kNumTypes (C++ enumerator)
transformer_engine::TensorWrapper (C++ struct)
transformer_engine::TensorWrapper::amax (C++ function)
transformer_engine::TensorWrapper::data (C++ function)
transformer_engine::TensorWrapper::dptr (C++ function)
transformer_engine::TensorWrapper::dtype (C++ function)
transformer_engine::TensorWrapper::operator= (C++ function), [1]
transformer_engine::TensorWrapper::scale (C++ function)
transformer_engine::TensorWrapper::scale_inv (C++ function)
transformer_engine::TensorWrapper::shape (C++ function)
transformer_engine::TensorWrapper::tensor_ (C++ member)
transformer_engine::TensorWrapper::TensorWrapper (C++ function), [1], [2], [3], [4]
transformer_engine::TensorWrapper::~TensorWrapper (C++ function)
TransformerLayer (class in transformer_engine.jax.flax)
(class in transformer_engine.paddle)
(class in transformer_engine.pytorch)
TransformerLayerType (class in transformer_engine.jax.flax)
U
update_collections() (in module transformer_engine.jax)
update_fp8_metas() (in module transformer_engine.jax)