Framework-specific API
- PyTorch
  - Linear
  - GroupedLinear
  - LayerNorm
  - RMSNorm
  - LayerNormLinear
  - LayerNormMLP
  - DotProductAttention
  - MultiheadAttention
  - TransformerLayer
  - CudaRNGStatesTracker
  - autocast()
  - quantized_model_init()
  - checkpoint()
  - make_graphed_callables()
  - get_cpu_offload_context()
  - mark_not_offload()
  - ManualOffloadSynchronizer
  - parallel_cross_entropy()
- Recipe availability
- Mixture of Experts (MoE) functions
- Communication-computation overlap
- Quantized tensors
- Quantizers
- Tensor saving and restoring functions
- Operation fuser
  - Sequential
  - FusibleOperation
  - BasicOperation
  - FusedOperation
  - register_forward_fusion()
  - register_backward_fusion()
  - Linear
  - AddExtraInput
  - AllGather
  - AllReduce
  - BasicLinear
  - Bias
  - ClampedSwiGLU
  - ConstantScale
  - Dropout
  - GEGLU
  - GELU
  - GLU
  - GroupedLinear
  - Identity
  - L2Normalization
  - LayerNorm
  - MakeExtraOutput
  - QGELU
  - QGEGLU
  - Quantize
  - ReGLU
  - ReLU
  - ReduceScatter
  - Reshape
  - RMSNorm
  - SReGLU
  - SReLU
  - ScaledSwiGLU
  - SiLU
  - SwiGLU
- Deprecated functions
- Jax
- Pre-defined Variable of Logical Axes
- Checkpointing
- Modules