sage_attention

SageAttention-style attention quantization for diffusers models.

apply_sage_attention patches a diffusers transformer to quantize the post-softmax P tile to NVFP4 E2M1 inside ModelOpt’s Triton flash-attention kernel (quantize_p=True). This is purely a quantization feature — it is independent of, and can be freely combined with, the sparse attention methods in modelopt.torch.sparsity.attention_sparsity.

Design

apply_sage_attention wraps the transformer’s forward once:

  1. Before the forward, it sets quantize_p=True in a thread-local store that the Triton kernel reads.

  2. It activates the modelopt_triton diffusers attention backend for the duration of the forward pass so that attention calls are routed to the ModelOpt Triton kernel.

  3. After the forward (finally block), it resets quantize_p=False.

Sparse attention methods (skip-softmax / N:M sparse softmax) manage their own thread-local params (threshold, sparsity_n/m, …) and deliberately do not touch quantize_p, enabling transparent combination:

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.quantization import apply_sage_attention

# SageAttention standalone — NVFP4 P-matrix quantization only
apply_sage_attention(transformer)

# Combined with N:M sparse softmax
mtsa.sparsify(transformer, mtsa.SPARSE_SOFTMAX_DEFAULT)
apply_sage_attention(transformer)

# Combined with skip-softmax tile pruning
mtsa.sparsify(transformer, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
apply_sage_attention(transformer)

Supported models

Currently targets diffusers transformer models (WAN, LTX, …) that use the diffusers attention-dispatch mechanism. The modelopt_triton backend is registered in diffusers._AttentionBackendRegistry on first call.

Requirements

  • CUDA GPU + Triton installed

  • modelopt.torch.sparsity.attention_sparsity (provides the Triton kernel and diffusers backend registration)

Functions

apply_sage_attention

apply_sage_attention(transformer, quantize_p=True)

Patch a diffusers transformer to use NVFP4 P-matrix quantization.

Wraps transformer.forward so that every call activates the modelopt_triton diffusers attention backend with quantize_p=True inside the Triton flash-attention kernel.

This is a standalone quantization feature and does not depend on or conflict with mtsa.sparsify(). Both can be applied to the same transformer — sparsity parameters and quantization parameters are stored in independent thread-local slots.

Parameters:
  • transformer (Module) – A diffusers transformer module (e.g. pipe.transformer for WAN2.2 / LTX Video).

  • quantize_p (bool) – If True (default), quantize the post-softmax P tile to NVFP4 E2M1 with per-tile max scaling inside the Triton kernel.
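
E2M1 (1 sign, 2 exponent, 1 mantissa bit) represents only eight magnitudes per sign: {0, 0.5, 1, 1.5, 2, 3, 4, 6}. A pure-Python sketch of per-tile max scaling into this format (illustrative only; the actual quantization happens inside the Triton kernel):

```python
# Magnitudes representable in NVFP4 E2M1.
E2M1_LEVELS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_tile_e2m1(tile):
    """Quantize a tile (list of floats) to E2M1 with per-tile max scaling.

    The tile's absolute max is mapped onto the largest E2M1 magnitude
    (6.0); each value is then rounded to the nearest representable level
    and rescaled. Hypothetical helper, not ModelOpt's kernel code.
    """
    amax = max(abs(v) for v in tile) or 1.0
    scale = amax / 6.0                       # per-tile max scaling factor

    def q(v):
        mag = min(abs(v) / scale, 6.0)       # clamp into E2M1 range
        nearest = min(E2M1_LEVELS, key=lambda lvl: abs(lvl - mag))
        return (nearest if v >= 0 else -nearest) * scale

    return [q(v) for v in tile]
```

Because post-softmax P values lie in [0, 1] and each tile is rescaled by its own max, the largest entry in every tile always lands exactly on a representable level, which is what makes this coarse 4-bit format workable for the P matrix.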

Raises:

ImportError – If modelopt.torch.sparsity.attention_sparsity is not installed (required for the Triton kernel and diffusers backend).

Return type:

None