Common API

class transformer_engine.common.recipe.Format

Supported FP8 formats.

Values
  • E4M3 – All FP8 tensors are in e4m3 format

  • E5M2 – All FP8 tensors are in e5m2 format

  • HYBRID – FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format
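The two formats trade precision for range, which is the usual motivation for pairing them in HYBRID. As a rough illustration, the sketch below derives the largest finite value of each format from its bit layout. The helper name `fp8_max` and the layout conventions are assumptions for illustration and are not defined on this page: E4M3 is taken to reserve only the all-ones mantissa at the top exponent for NaN, and E5M2 is taken to follow IEEE-style rules.

```python
def fp8_max(exp_bits: int, man_bits: int, ieee_like: bool) -> float:
    """Largest finite value of a sign/exponent/mantissa FP8 layout (illustrative helper)."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        # The top exponent code is reserved for inf/NaN, so the largest finite
        # value uses the code below it together with a full mantissa.
        max_exp = (2 ** exp_bits - 2) - bias
        max_man = 2 - 2 ** (-man_bits)
    else:
        # The top exponent code is usable; only the all-ones mantissa is NaN,
        # so the largest mantissa is one step below full.
        max_exp = (2 ** exp_bits - 1) - bias
        max_man = 2 - 2 ** (1 - man_bits)
    return max_man * 2.0 ** max_exp


E4M3_MAX = fp8_max(4, 3, ieee_like=False)  # more mantissa bits, smaller range
E5M2_MAX = fp8_max(5, 2, ieee_like=True)   # fewer mantissa bits, larger range
```

Under these assumptions E4M3 tops out at 448 and E5M2 at 57344, so E5M2 covers a much wider dynamic range at the cost of one mantissa bit.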

class transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo='max', scaling_factor_compute_algo=None, override_linear_precision=(False, False, False), reduce_amax=True)

Use the delayed scaling factor strategy: each iteration uses the scaling factor computed in a previous iteration, the scaling factor is recomputed once every interval iterations, and an amax history of amax_history_len steps is recorded.
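The timing of this strategy can be sketched in plain Python. This is a simplified model and not the library's implementation: the function name `simulate_delayed_scaling`, the initial scale of 1.0, the use of the window maximum, and the assumption that all amax values are positive are choices made for illustration only.

```python
from collections import deque


def simulate_delayed_scaling(amaxes, fp8_max=448.0, interval=1, history_len=4):
    """Return the scale applied at each step for a stream of per-step amax values."""
    history = deque(maxlen=history_len)  # rolling amax window
    scale = 1.0                          # assumed initial scale, before any history exists
    used = []
    for step, amax in enumerate(amaxes):
        used.append(scale)               # this step runs with the previously computed scale
        history.append(amax)             # amax observed while running this step
        if (step + 1) % interval == 0:   # recompute once every `interval` steps
            scale = fp8_max / max(history)
    return used


every_step = simulate_delayed_scaling([2.0, 4.0, 1.0], interval=1)
every_other = simulate_delayed_scaling([2.0, 4.0, 1.0], interval=2)
```

In this model the scale applied at a given step never depends on that step's own amax, which is the "delayed" part of the strategy. A larger interval simply holds a stale scale for longer.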

Parameters
  • margin (int, default = 0) – Margin for the scaling factor computation. With the default algorithm, the scaling factor is divided by 2 ^ margin.

  • interval (int, default = 1) – Controls how often the scaling factor is recomputed.

  • fp8_format ({Format.E4M3, Format.HYBRID}, default = Format.HYBRID) – Controls the FP8 data format used during forward and backward pass.

  • amax_history_len (int, default = 1024) – The length of the amax history window used for scaling factor computation.

  • amax_compute_algo ({'max', 'most_recent', Callable}, default = 'max') –

    Algorithm used for choosing the amax value for the scaling factor computation. There are two predefined choices: 'max' chooses the largest amax in the history window, while 'most_recent' always chooses the most recently seen value. Alternatively, one may pass a function with the signature:

    def amax_compute(amax_history: Tensor) -> Tensor
    

    where Tensor is a framework tensor type.

  • scaling_factor_compute_algo (Callable, default = None) –

    Algorithm used for computing the new scaling factor based on the value of amax. It should be a function with the signature:

    def scaling_factor_compute(amax: Tensor,
                               old_scaling_factor: Tensor,
                               fp8_max: Tensor,
                               recipe: DelayedScaling) -> Tensor
    

    where Tensor is a framework tensor type.

  • override_linear_precision (Tuple(bool, bool, bool), default = (False, False, False)) – Whether or not to execute the fprop, dgrad, and wgrad GEMMs (respectively) in higher precision when using FP8.

  • reduce_amax (bool, default = True) – By default, if torch.distributed is initialized, the amax value for FP8 tensors is reduced across the fp8_group (specified in the fp8_autocast call). This keeps the amaxes and scaling factors synced across the given distributed group. If set to False, this reduction is skipped and every GPU maintains local amaxes and scaling factors. To ensure results are numerically identical across checkpointing boundaries in this case, all ranks must checkpoint in order to store the local tensors.
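The two callable parameters above can be illustrated with a minimal sketch that follows the documented signatures. To keep it self-contained, plain Python floats and sequences stand in for the framework Tensor type, and `RecipeStandIn` is a hypothetical stand-in for DelayedScaling that models only margin. The averaging rule and the "keep the old factor" guard are illustrative choices, not behavior described on this page.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass
class RecipeStandIn:
    """Hypothetical stand-in for DelayedScaling; only `margin` is modelled."""
    margin: int = 0


def amax_compute(amax_history: Sequence[float]) -> float:
    """Custom choice: mean of the window, a smoother alternative to 'max'."""
    return sum(amax_history) / len(amax_history)


def scaling_factor_compute(amax: float,
                           old_scaling_factor: float,
                           fp8_max: float,
                           recipe: RecipeStandIn) -> float:
    """Keep the old factor when amax is unusable, otherwise rescale with the margin."""
    if amax <= 0.0 or not math.isfinite(amax):
        return old_scaling_factor
    return (fp8_max / amax) / (2 ** recipe.margin)


amax = amax_compute([1.0, 3.0, 2.0])
new_scale = scaling_factor_compute(amax, 1.0, 448.0, RecipeStandIn(margin=1))
kept_scale = scaling_factor_compute(0.0, 7.0, 448.0, RecipeStandIn())
```

Real implementations would operate on framework tensors, so the Python `if` would need to become an elementwise tensor operation, and would be passed as amax_compute_algo and scaling_factor_compute_algo. Averaging can underestimate the true maximum, so a rule like this trades some overflow risk for smoother scaling factors.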

Notes

  • By default (when scaling_factor_compute_algo is left as None) the scaling factor is computed from the final amax value using the formula:

    FP8_MAX = maximum_representable_value(fp8_format)
    new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)
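The formula can be checked with a small worked example. The sketch below reads `^` as exponentiation and uses plain floats. The function name `default_scaling_factor` and the value 448 for the E4M3 maximum are assumptions made for illustration.

```python
def default_scaling_factor(amax: float, fp8_max: float, margin: int = 0) -> float:
    """new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin), with ^ as a power."""
    return (fp8_max / amax) / (2 ** margin)


E4M3_MAX = 448.0  # assumed maximum representable value for e4m3

no_margin = default_scaling_factor(amax=2.0, fp8_max=E4M3_MAX, margin=0)
with_margin = default_scaling_factor(amax=2.0, fp8_max=E4M3_MAX, margin=2)

# After scaling, the largest observed magnitude lands at FP8_MAX / 2^margin.
scaled_peak = 2.0 * with_margin
```

With margin = 0 the largest observed value maps exactly onto FP8_MAX. Each unit of margin halves the scaling factor, leaving one extra power of two of headroom for values larger than the recorded amax.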