Transformer Engine
1.2.0dev-e10997b
Framework-specific API
pyTorch
    Linear
        forward
        set_tensor_parallel_group
    LayerNorm
    RMSNorm
    LayerNormLinear
        forward
        set_tensor_parallel_group
    LayerNormMLP
        forward
        set_tensor_parallel_group
    DotProductAttention
        forward
        set_context_parallel_group
    MultiheadAttention
        forward
        set_context_parallel_group
        set_tensor_parallel_group
    TransformerLayer
        forward
        set_context_parallel_group
        set_tensor_parallel_group
    InferenceParams
    CudaRNGStatesTracker
        add
        fork
        get_states
        reset
        set_states
    fp8_autocast
    fp8_model_init
    checkpoint
    onnx_export
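The pyTorch modules listed above are drop-in replacements for torch.nn layers that run their GEMMs in FP8 when wrapped in fp8_autocast. Below is a minimal usage sketch; the layer sizes and the DelayedScaling recipe settings (HYBRID format, amax history of 16) are illustrative assumptions, not requirements of the API.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# Illustrative sizes, chosen only for this sketch.
hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
seq_len, batch_size = 128, 4

# A fused Transformer block built from the modules listed above.
layer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads).cuda()
x = torch.randn(seq_len, batch_size, hidden_size, device="cuda")

# Delayed-scaling recipe: E4M3 forward / E5M2 backward (Format.HYBRID).
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16,
                            amax_compute_algo="max")

# Linear, LayerNormLinear, LayerNormMLP and the attention modules pick up
# FP8 execution from the surrounding autocast context.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(x)

out.sum().backward()
```

The same fp8_autocast context works around the individual modules (Linear, LayerNormMLP, DotProductAttention, and so on) when they are used on their own rather than through TransformerLayer.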
Jax
    MajorShardingType
    ShardingType
    TransformerLayerType
    ShardingResource
    fp8_autocast
    update_collections
    update_fp8_metas
    LayerNorm
        __call__
    DenseGeneral
        __call__
    LayerNormDenseGeneral
        __call__
    LayerNormMLP
        __call__
    RelativePositionBiases
        __call__
    MultiHeadAttention
        __call__
    TransformerLayer
        __call__
    extend_logical_axis_rules
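For Jax, the Flax modules above are initialized and applied inside fp8_autocast, which also creates the FP8 metadata collections that update_collections and update_fp8_metas operate on. A minimal sketch follows, assuming the Flax layers are exposed under transformer_engine.jax.flax as in recent releases; the shapes and recipe values are made up for illustration.

```python
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax  # assumed module path for the Flax layers
from transformer_engine.common.recipe import Format, DelayedScaling

# Illustrative sizes for this sketch.
batch, hidden = 16, 1024

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=hidden)
    x = jnp.ones((batch, hidden), dtype=jnp.bfloat16)
    # init returns the parameters together with the FP8 meta collections.
    variables = model.init(jax.random.PRNGKey(0), x)
    y = model.apply(variables, x)
```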
paddle
    Linear
        forward
    LayerNorm
    LayerNormLinear
        forward
    LayerNormMLP
        forward
    FusedScaleMaskSoftmax
        forward
    DotProductAttention
        forward
    MultiHeadAttention
        forward
    TransformerLayer
        forward
    fp8_autocast
    recompute
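The paddle modules follow the same pattern as their pyTorch counterparts: construct the layer, then run the forward pass under fp8_autocast. A minimal sketch with illustrative sizes and recipe values; only Linear and fp8_autocast from the list above are exercised here.

```python
import paddle
import transformer_engine.paddle as te
from transformer_engine.common.recipe import Format, DelayedScaling

# Illustrative sizes for this sketch.
in_features, out_features, batch = 768, 768, 16

layer = te.Linear(in_features, out_features)
x = paddle.randn([batch, in_features])

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

# The GEMM inside te.Linear runs in FP8 within this context.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)

y.mean().backward()
```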