import functools
import math
import os
import weakref
from typing import List, Optional, Union, cast

import torch
from torch import nn

import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import (get_sm_version, is_sm_100f, nvtx_range,
                                 nvtx_range_debug)
from tensorrt_llm.llmapi.llm_args import SkipSoftmaxAttentionConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                 FlashInferAttentionMetadata, TrtllmAttention,
                                 TrtllmAttentionMetadata)
from ..attention_backend.interface import (AttentionBackend, AttentionMask,
                                           CustomAttentionMask,
                                           PositionalEmbeddingParams,
                                           PredefinedAttentionMask)
from ..attention_backend.sparse.dsa import (
    DSAtrtllmAttentionMetadata, transform_local_topk_and_prepare_pool_view)
from ..attention_backend.utils import create_attention, get_attention_backend
from ..distributed import (AllReduceParams, HelixAllToAllNative, alltoall_helix,
                           cp_allgather, reducescatter)
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
                     is_torch_compiling, maybe_compiled_cat,
                     maybe_compiled_copy_)
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding

# Import FlashMLA sparse attention kernel
try:
    from tensorrt_llm.flash_mla import flash_mla_sparse_fwd
except ImportError:
    flash_mla_sparse_fwd = None


def extract_extra_attrs(layer_idx: str, attn_type: str):
    """Resolve the attention metadata and the registered layer for a custom op.

    Custom ops only receive tensors and plain values, so the metadata and the
    layer module are fetched from the model's extra attrs, where both are
    stored as weak references (they are dereferenced by calling them).

    Args:
        layer_idx: Key the layer was registered under in
            ``extra_attrs[attn_type + "_layers"]``.
        attn_type: Either ``"mla"`` or ``"attn"``.

    Returns:
        A ``(metadata, attn_layer)`` tuple.
    """
    assert attn_type in ("mla", "attn"), "Invalid attention type"
    extra_attrs = get_model_extra_attrs()
    assert extra_attrs is not None, "Model extra attrs is not set"

    metadata_weak = extra_attrs.get("attention_metadata", None)
    assert metadata_weak is not None, "Attention metadata is not set"
    metadata = metadata_weak()

    is_mla = attn_type == "mla"
    # MLA only accepts TRTLLM metadata; regular attention also accepts FlashInfer.
    if is_mla:
        accepted_metadata = (TrtllmAttentionMetadata, )
    else:
        accepted_metadata = (FlashInferAttentionMetadata,
                             TrtllmAttentionMetadata)
    assert isinstance(metadata, accepted_metadata)

    registry = extra_attrs.get(attn_type + "_layers", None)
    assert registry is not None, "Attention layer is not registered"
    layer_weak = registry.get(layer_idx, None)
    assert layer_weak is not None, f"Cannot find attention layer for layer {layer_idx}"
    attn_layer = layer_weak()

    if is_mla:
        assert isinstance(
            attn_layer,
            MLA), "MLA layer must be a subclass of MLA or an instance of MLA"
    elif attn_type == "attn":
        assert isinstance(
            attn_layer, Attention
        ), "Attention layer must be a subclass of Attention or an instance of Attention"

    return metadata, attn_layer


def create_attn_outputs_impl(q: torch.Tensor, attention_mask: str,
                             layer_idx: str) -> List[torch.Tensor]:
    """Return the output tensors produced by the registered layer's
    ``create_output`` for *q*.

    The layer is looked up by *layer_idx* and paired with the current
    attention metadata via ``extract_extra_attrs(layer_idx, "attn")``.
    """
    attn_metadata, layer = extract_extra_attrs(layer_idx, "attn")
    outputs = layer.create_output(q, attn_metadata, attention_mask)
    return outputs


@torch.library.custom_op("trtllm::create_attn_outputs", mutates_args=())
def create_attn_outputs(q: torch.Tensor, attention_mask: str,
                        layer_idx: str) -> List[torch.Tensor]:
    """Custom-op entry point that creates the output tensors for an attention layer.

    Registered as ``trtllm::create_attn_outputs`` with no mutated arguments.
    It delegates to ``create_attn_outputs_impl``, which resolves the layer
    registered under *layer_idx* and calls its ``create_output``.
    """
    return create_attn_outputs_impl(q, attention_mask, layer_idx)


@create_attn_outputs.register_fake
def _(q, attention_mask, layer_idx):
    # Fake implementation of trtllm::create_attn_outputs, used when tracing
    # with fake tensors (e.g. under torch.compile). It simply reuses the real
    # implementation. NOTE(review): this assumes create_output only creates
    # tensors from q's metadata and is therefore safe on fake tensors - confirm.
    return create_attn_outputs_impl(q, attention_mask, layer_idx)


@torch.library.custom_op("trtllm::attn_custom_op_inplace",
                         mutates_args=("output", "output_sf"))
def attn_custom_op_inplace(
    q: torch.Tensor,
    k: Optional[torch.Tensor],
    v: Optional[torch.Tensor],
    attention_mask: str,
    mrope_rotary_cos_sin: Optional[torch.Tensor],
    mrope_position_deltas: Optional[torch.Tensor],
    attention_window_size: Optional[int],
    attention_mask_data: Optional[torch.Tensor],
    attention_sinks: Optional[torch.Tensor],
    layer_idx: str,
    output: torch.Tensor,
    output_sf: Optional[torch.Tensor],
) -> None:
    """Run ``_attn_impl`` of a registered attention layer, writing in place.

    Registered as the ``trtllm::attn_custom_op_inplace`` custom op, which
    declares ``output`` and ``output_sf`` as mutated. The layer and the current
    attention metadata are resolved from *layer_idx*. *attention_mask* is a
    string that is turned back into a ``CustomAttentionMask`` (if it equals
    ``CustomAttentionMask.CUSTOM``) or a ``PredefinedAttentionMask``.
    """
    attn_metadata, layer = extract_extra_attrs(layer_idx, "attn")

    if attention_mask != CustomAttentionMask.CUSTOM:
        mask_obj = PredefinedAttentionMask(attention_mask)
    else:
        mask_obj = CustomAttentionMask(attention_mask)

    # NVFP4 output cannot be supported by torch compile for TRTLLM backend.
    layer._attn_impl(
        q,
        k,
        v,
        attn_metadata,
        mask_obj,
        mrope_rotary_cos_sin,
        mrope_position_deltas,
        attention_window_size,
        attention_mask_data,
        output=output,
        output_sf=output_sf,
        attention_sinks=attention_sinks,
    )


def _helix_post_process(
    partial_o: torch.Tensor,
    softmax_stats: torch.Tensor,
    mapping: Mapping,
    num_heads_tp_cp: int,
    value_dim: int,
    aux_stream: Optional[torch.cuda.Stream] = None,
    ln_events: Optional[list] = None,
) -> torch.Tensor:
    """Exchange partial attention outputs across the Helix CP group and
    merge them into the final attention output.

    Shared by the MHA (``Attention``) and MLA modules; the two callers only
    differ in *value_dim* (``head_dim`` for MHA, ``kv_lora_rank`` for MLA).

    The transport is picked from ``mapping.cp_config``:

    * ``use_nccl_for_alltoall`` (default ``True``): NCCL ``alltoall_helix``.
    * otherwise the MNNVL FIFO transport (``HelixAllToAllNative``), with
      ``fifo_version`` 1 or any other value (default 2).

    For FIFO-v1, the two reshape + ``.contiguous()`` copies are overlapped on
    *aux_stream* when both *aux_stream* and *ln_events* are supplied.
    """
    if mapping.cp_config.get("use_nccl_for_alltoall", True):
        # NCCL transport: bring dim 1 to the front, cut it into cp_size equal
        # pieces, and hand the pieces of both tensors to alltoall_helix.
        send_chunks = []
        for src in (partial_o, softmax_stats):
            swapped = src.transpose(1, 0).contiguous()
            piece_len = swapped.shape[0] // mapping.cp_size
            send_chunks.extend(torch.split(swapped, piece_len))
        received = alltoall_helix(send_chunks, mapping.cp_group)
        received = [r.transpose(1, 2).contiguous() for r in received]
        return torch.ops.trtllm.helix_post_process(received[0], received[1],
                                                   1.0)

    # MNNVL FIFO transport.
    helix = HelixAllToAllNative.get(mapping)
    num_tokens = partial_o.shape[0]
    cp_size = mapping.cp_size
    fifo_version = mapping.cp_config.get("fifo_version", 2)

    if fifo_version != 1:
        # FIFO-v2 (any version other than 1): exchange flattened per-CP-rank
        # slices, then view the received buffers as 4-D for the combine op.
        o_recv, stats_recv = helix.alltoall_native(
            partial_o.view(num_tokens, cp_size, num_heads_tp_cp * value_dim),
            softmax_stats.view(num_tokens, cp_size, num_heads_tp_cp * 2))
        return torch.ops.trtllm.helix_post_process_native(
            o_recv.view(num_tokens, cp_size, num_heads_tp_cp, value_dim),
            stats_recv.view(num_tokens, cp_size, num_heads_tp_cp, 2), 1.0, 1)

    # FIFO-v1: reorder both tensors to (tokens, heads, cp, ...) before the
    # exchange. The helpers only read the original inputs, which are never
    # rebound, so there is no late-binding hazard.
    o_shape = (num_tokens, cp_size, num_heads_tp_cp, value_dim)
    stats_shape = (num_tokens, cp_size, num_heads_tp_cp, 2)

    def reorder_o():
        return partial_o.view(*o_shape).transpose(1, 2).contiguous()

    def reorder_stats():
        return softmax_stats.view(*stats_shape).transpose(1, 2).contiguous()

    if aux_stream is not None and ln_events is not None:
        o_send, stats_send = maybe_execute_in_parallel(
            reorder_o,
            reorder_stats,
            ln_events[0],
            ln_events[1],
            aux_stream,
        )
    else:
        o_send = reorder_o()
        stats_send = reorder_stats()

    o_recv, stats_recv = helix.alltoall_native(o_send, stats_send)
    return torch.ops.trtllm.helix_post_process_native(o_recv, stats_recv, 1.0,
                                                      2)


def _helix_cp_pad(tensor: torch.Tensor, num_tokens: int,
                  cp_size: int) -> tuple[torch.Tensor, int]:
    """Pad tensor along dim-0 so its length is divisible by cp_size."""
    chunk_size = math.ceil(num_tokens / cp_size)
    padded_size = chunk_size * cp_size
    if num_tokens < padded_size:
        tensor = torch.nn.functional.pad(tensor,
                                         (0, 0, 0, padded_size - num_tokens),
                                         mode="constant",
                                         value=0)
    return tensor, chunk_size


def _helix_cp_allgather_input(hidden_states: torch.Tensor,
                              attn_metadata: AttentionMetadata,
                              mapping: Mapping, layer_idx: int) -> torch.Tensor:
    """Undo the previous layer's reduce-scatter by all-gathering over CP.

    This applies only with Helix CP plus attention DP and only to layers after
    the first. The first layer's input comes straight from the embedding and
    is already full. After gathering, the rows are trimmed back to
    ``attn_metadata.num_tokens`` to drop any padding added before the
    reduce-scatter.
    """
    needs_gather = (mapping.has_cp_helix() and mapping.enable_attention_dp
                    and layer_idx > 0)
    if not needs_gather:
        return hidden_states
    gathered = cp_allgather(hidden_states, mapping, dim=0)
    return gathered[:attn_metadata.num_tokens]


def _helix_cp_output_projection(
    o_proj: Linear,
    attn_output: torch.Tensor,
    attn_metadata: AttentionMetadata,
    all_reduce_params: Optional[AllReduceParams],
    mapping: Mapping,
    mapping_o: Mapping,
    layer_idx: int,
    lora_params: Optional[dict] = None,
) -> torch.Tensor:
    """Run the attention output projection, reduce-scattering under Helix CP+DP.

    With Helix CP and attention DP both enabled, ``o_proj`` runs with its
    all-reduce disabled. Its output is zero-padded along dim 0 to a multiple of
    the CP size and reduce-scattered over ``mapping_o``. The reduce-scatter
    sums the partial results and leaves each rank with a distinct token chunk
    for the MLP. Otherwise ``o_proj`` runs with the caller's
    *all_reduce_params* (the regular all-reduce path).
    """
    use_reduce_scatter = mapping.has_cp_helix() and mapping.enable_attention_dp

    if use_reduce_scatter:
        proj_params = AllReduceParams(enable_allreduce=False)
    else:
        proj_params = all_reduce_params

    projected = o_proj(attn_output,
                       all_reduce_params=proj_params,
                       lora_params=lora_params,
                       layer_idx=layer_idx)

    if not use_reduce_scatter:
        return projected

    padded, _ = _helix_cp_pad(projected, attn_metadata.num_tokens,
                              mapping.cp_size)
    return reducescatter(padded, mapping_o, dim=0)


def maybe_slice_for_helix_cp(tensor: torch.Tensor,
                             attn_metadata: AttentionMetadata,
                             mapping_with_cp: Optional[Mapping],
                             layer_idx: int) -> torch.Tensor:
    """Cut *tensor* down to this CP rank's token chunk after reduce-scatter.

    Only the first decoder layer (``layer_idx == 0``) with Helix CP and
    attention DP needs this. There the residual comes directly from the
    embedding, so it still holds every token and must be aligned with the
    reduce-scattered attention output. It is zero-padded to a multiple of the
    CP size and rank ``cp_rank`` keeps rows
    ``[cp_rank * chunk, (cp_rank + 1) * chunk)``. In every other case the
    tensor is returned untouched.

    Intended to be called from the decoder layer on the residual *after* the
    attention forward, so the Attention/MLA forward signatures stay unchanged.
    """
    if (mapping_with_cp is None or not mapping_with_cp.has_cp_helix()
            or not mapping_with_cp.enable_attention_dp or layer_idx != 0):
        return tensor

    padded, chunk_size = _helix_cp_pad(tensor, attn_metadata.num_tokens,
                                       mapping_with_cp.cp_size)
    offset = mapping_with_cp.cp_rank * chunk_size
    return padded[offset:offset + chunk_size]


def maybe_allgather_for_helix_cp(
        hidden_states: torch.Tensor, attn_metadata: AttentionMetadata,
        mapping_with_cp: Optional[Mapping]) -> torch.Tensor:
    """Bring back the full token set after the final layer's reduce-scatter.

    Under Helix CP with attention DP, every decoder layer ends with a
    reduce-scatter that leaves each CP rank with only its own token chunk.
    This all-gathers across the CP group and trims the result to
    ``attn_metadata.num_tokens`` rows. That way the final norm and LM head see
    every token. It is a no-op in any other configuration.

    Call it at the end of the model's ``forward()``, after the decoder-layer
    loop.
    """
    helix_dp_active = (mapping_with_cp is not None
                       and mapping_with_cp.has_cp_helix()
                       and mapping_with_cp.enable_attention_dp)
    if not helix_dp_active:
        return hidden_states
    full = cp_allgather(hidden_states, mapping_with_cp, dim=0)
    return full[:attn_metadata.num_tokens]


class Attention(nn.Module):

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        max_position_embeddings: int,
        bias: bool,
        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        rope_fusion: Optional[bool] = None,
        layer_idx: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        dense_bias: Optional[bool] = None,
        config: Optional[ModelConfig] = None,
        q_scaling: float = 1.0,
        attention_chunk_size: Optional[int] = None,
        disable_deep_gemm: bool = False,
        attn_output_gate: Optional[bool] = None,
        use_custom_cublas_mm: bool = False,
        reduce_output: bool = True,
        mapping_with_cp: Optional[Mapping] = None,
    ):
        """
        Initialize the Attention module.

        Args:
            hidden_size (int): The size of the hidden dimension.
            num_attention_heads (int): The total number of attention heads (before TP sharding).
            num_key_value_heads (int): The total number of key value heads (before TP sharding).
            max_position_embeddings (int): The maximum position embeddings.
            bias (bool): Whether to use bias in the QKV projection. Also the default for dense_bias.
            pos_embd_params (Optional[PositionalEmbeddingParams]): The positional embedding parameters.
            rope_fusion (Optional[bool]): Whether to fuse RoPE into the attention OP and skip applying unfused RoPE. If None, whether to fuse is decided by the capability of the attention backend.
            layer_idx (Optional[int]): The layer index. Its string form is the key under which this module is registered in config.extra_attrs["attn_layers"]; a "_<n>" suffix is appended on collision.
            dtype (Optional[torch.dtype]): The data type passed to the QKV and output projection Linear layers.
            dense_bias (Optional[bool]): Whether to use bias in the output projection layer. Defaults to ``bias`` when None.
            config (Optional[ModelConfig]): The model configuration. A default ModelConfig is used when None, and the layer is then not registered in extra_attrs.
            q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
            attention_chunk_size (Optional[int]): See [Chunked Attention] below.
            disable_deep_gemm (bool): Whether to disable the use of DeepGEMM in Linear layers (currently only matters on SM100 + FP8).
            attn_output_gate (Optional[bool]): Whether the attention uses an output gate. When truthy, qkv_proj produces twice the Q width (Q plus gate), and attention output quantization is turned off (see _use_quantize_output).
            use_custom_cublas_mm (bool): Forwarded to both the QKV and output projection Linear layers.
            reduce_output (bool): Forwarded to the output projection Linear layer as ``reduce_output``.
            mapping_with_cp (Optional[Mapping]): Override mapping with CP configuration. When given, it is used instead of ``config.mapping``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_idx_str = str(layer_idx)

        self.register_to_config = False
        # Register a weak reference to this module in config.extra_attrs so the
        # string-keyed custom ops (extract_extra_attrs) can find it.
        # NOTE(review): registration happens whenever a config is passed,
        # regardless of the attention backend.
        if config is not None:
            if "attn_layers" not in config.extra_attrs:
                config.extra_attrs["attn_layers"] = {}
            suffix = 0
            # Makes sure there is no duplicate attention layer identifier.
            while self.layer_idx_str in config.extra_attrs["attn_layers"]:
                self.layer_idx_str = str(layer_idx) + f"_{suffix}"
                suffix += 1
            config.extra_attrs["attn_layers"][self.layer_idx_str] = weakref.ref(
                self)
            self.register_to_config = True

        # Fall back to a default config for the attribute reads below; a
        # default config is never registered (registration already happened).
        config = config or ModelConfig()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        # Prefer an explicit integer head_dim from the pretrained config;
        # otherwise derive it as hidden_size // num_heads.
        self.head_dim = getattr(config.pretrained_config, 'head_dim', None)
        if not isinstance(self.head_dim, int):
            self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = num_key_value_heads
        # Computed from the unsharded head counts.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.pos_embd_params = pos_embd_params
        self.dense_bias = dense_bias
        self.q_scaling = q_scaling
        self.attn_output_gate = attn_output_gate

        if self.attn_output_gate:
            logger.info_once("using attn output gate!", key="attn_output_gate")

        # [Chunked Attention]
        # Chunked attention is applied to context requests only. Chunked attention will be
        # applied when this field is specified and mMaskType == CAUSAL.
        #
        # In chunked attention, we break context requests into chunks of a specified size. Tokens can only
        # attend to tokens in the same chunk. So, for example, if the chunk size is 3, we might have a mask
        # that looks like this:
        #
        # 1 0 0 0 0 0
        # 1 1 0 0 0 0
        # 1 1 1 0 0 0
        # 0 0 0 1 0 0
        # 0 0 0 1 1 0
        # 0 0 0 1 1 1
        self.attention_chunk_size = attention_chunk_size

        if dense_bias is None:
            self.dense_bias = bias

        # tensor parallel
        if mapping_with_cp is not None:
            logger.warning_once(
                "[Attention::__init__] Overriding mapping with CP detected.",
                key="attention_init_mapping_with_cp")
            self.mapping = mapping_with_cp
        else:
            self.mapping = config.mapping

        tp_size = self.mapping.tp_size
        pp_size = self.mapping.pp_size
        cp_size = self.mapping.cp_size
        # With attention DP, attention TP is collapsed to 1 and the former TP
        # ranks become DP ranks, folded into the pp dimension of the local
        # mappings built below.
        dp_size = 1
        if self.mapping.enable_attention_dp:
            dp_size = tp_size
            tp_size = 1

        # Only Helix context parallelism is supported by this module.
        if self.mapping.cp_size > 1:
            assert self.mapping.has_cp_helix(
            ), f"CP type must be HELIX for Attention, but got {self.mapping.cp_config['cp_type']}."

        # Local mapping for the QKV projection (keeps the CP config).
        mapping = Mapping(
            world_size=dp_size * tp_size * pp_size * cp_size,
            tp_size=tp_size,
            pp_size=pp_size * dp_size,
            cp_size=cp_size,
            cp_config=self.mapping.cp_config,
            rank=self.mapping.rank,
            gpus_per_node=self.mapping.gpus_per_node,
            enable_attention_dp=self.mapping.enable_attention_dp,
        )
        self.tp_size = tp_size
        self.cp_size = cp_size
        self.tp_rank = mapping.tp_rank
        assert self.num_heads % (tp_size * cp_size) == 0
        # Per-rank head counts. num_heads is sharded over TP only;
        # num_heads_tp_cp additionally divides by CP (used by the Helix
        # post-process). KV heads use ceil-division - presumably so each rank
        # keeps at least one (replicated) KV head when tp_size exceeds the KV
        # head count; TODO confirm.
        self.num_heads = self.num_heads // tp_size
        self.num_heads_tp_cp = self.num_heads // cp_size
        self.num_key_value_heads = (self.num_key_value_heads + tp_size -
                                    1) // tp_size
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim

        self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
        self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm

        # (offset, size) of q, k and v inside the fused per-rank QKV output.
        # With an output gate, the q slot is doubled to also hold the gate.
        qkv_shard_indices_mapping = {
            "q": (0, self.q_size * (2 if self.attn_output_gate else 1)),
            "k":
            (self.q_size * (2 if self.attn_output_gate else 1), self.kv_size),
            "v":
            (self.q_size * (2 if self.attn_output_gate else 1) + self.kv_size,
             self.kv_size),
        }

        # Column-parallel fused QKV projection; the global out-features is the
        # per-rank q/k/v width times tp_size.
        self.qkv_proj = Linear(
            self.hidden_size,
            tp_size * self.q_size * (2 if self.attn_output_gate else 1) +
            2 * tp_size * self.kv_size,
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.FUSED_QKV_LINEAR),
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            disable_deep_gemm=disable_deep_gemm,
            use_custom_cublas_mm=use_custom_cublas_mm,
            fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

        self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                [self.hidden_size])

        # For Helix CP, combine TP and CP for the output projection so each
        # rank's o_proj input is num_heads_tp_cp * head_dim.
        mapping_o = Mapping(
            world_size=dp_size * tp_size * pp_size * cp_size,
            tp_size=tp_size * cp_size,
            pp_size=pp_size * dp_size,
            cp_size=1,
            rank=self.mapping.rank,
            gpus_per_node=self.mapping.gpus_per_node,
            enable_attention_dp=self.mapping.enable_attention_dp,
        )
        self.mapping_o = mapping_o

        # Row-parallel output projection over the combined TP*CP group.
        self.o_proj = Linear(
            tp_size * self.q_size,
            self.hidden_size,
            bias=self.dense_bias,
            dtype=dtype,
            mapping=mapping_o,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.o_lora,
            reduce_output=reduce_output,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            disable_deep_gemm=disable_deep_gemm,
            use_custom_cublas_mm=use_custom_cublas_mm,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

        self.quant_config = config.get_quant_config()
        self.attn_backend = config.attn_backend

        # Resolve target_sparsity → threshold_scale_factor if needed
        sparse_attn_cfg = config.sparse_attention_config
        if (isinstance(sparse_attn_cfg, SkipSoftmaxAttentionConfig)
                and sparse_attn_cfg.target_sparsity is not None):
            hf_sparse = getattr(config.pretrained_config,
                                'sparse_attention_config', None)
            if not isinstance(hf_sparse, dict):
                raise ValueError(
                    "sparse_attention_config with target_sparsity requires formula "
                    "coefficients in the model's config.json "
                    "(sparse_attention_config.threshold_scale_factor.{prefill,decode}.{a,b}), "
                    "but sparse_attention_config was not found or was not dict type in config.json."
                )
            formula = hf_sparse.get('threshold_scale_factor', {})
            sparse_attn_cfg = sparse_attn_cfg.resolve_for_target_sparsity(
                formula)

        attn_cls = get_attention_backend(self.attn_backend,
                                         sparse_attn_config=sparse_attn_cfg)

        # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
        # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
        # handles them as a single fused operation.
        self.splitted_qkv_lora = LoraLayer([
            LoraModuleType.ATTENTION_Q, LoraModuleType.ATTENTION_K,
            LoraModuleType.ATTENTION_V
        ], [self.q_size, self.kv_size, self.kv_size])
        self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV],
                                        [self.q_size + 2 * self.kv_size])

        # Whether to fuse RoPE into the attention OP.
        # If true, RoPE will be applied in self.attn.forward.
        # If false, RoPE will be applied in self.apply_rope.
        self.rope_fusion = rope_fusion

        if config.sparse_attention_config is not None:
            # Log sparse attention configuration once
            algo = config.sparse_attention_config.algorithm
            cfg_dump = config.sparse_attention_config.model_dump(
                exclude_none=True)
            logger.info_once(f"Using sparse attention: {algo} {cfg_dump}",
                             key="sparse_attention_config")

            if config.sparse_attention_config.algorithm == "rocket":
                logger.warning_once("disable rope_fusion for RocketKV.",
                                    key="disable_rope_fusion_for_rocketkv")
                self.rope_fusion = False

        if self.rope_fusion and not attn_cls.support_fused_rope():
            logger.warning_once(
                "rope_fusion is true but the attention backend does not support it. Will disable rope_fusion.",
                key="disable_rope_fusion_for_non_supported_backend")
            self.rope_fusion = False
        # If rope_fusion is not specified, enable if the attention backend supports it.
        if self.rope_fusion is None:
            self.rope_fusion = attn_cls.support_fused_rope()

        # Unfused RoPE module, built only when RoPE is not fused and positional
        # embedding params were provided.
        self.rotary_emb = None
        if not self.rope_fusion and self.pos_embd_params is not None:
            if self.pos_embd_params.type.is_mrope():
                self.rotary_emb = MRotaryEmbedding(
                    self.pos_embd_params.rope,
                    head_dim=self.head_dim,
                    is_neox=self.pos_embd_params.is_neox,
                    mrope_section=self.pos_embd_params.mrope_section,
                    mrope_interleaved=self.pos_embd_params.mrope_interleaved)
            else:
                self.rotary_emb = RotaryEmbedding(
                    self.pos_embd_params.rope,
                    head_dim=self.head_dim,
                    is_neox=self.pos_embd_params.is_neox,
                )

        # The backend attention op receives the per-rank (TP-sharded) head
        # counts, and positional params only when RoPE is fused into it.
        self.attn = create_attention(
            self.attn_backend,
            self.layer_idx,
            self.num_heads,
            self.head_dim,
            self.num_key_value_heads,
            pos_embd_params=self.pos_embd_params if self.rope_fusion else None,
            quant_config=self.quant_config,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            q_scaling=self.q_scaling,
            attention_chunk_size=self.attention_chunk_size,
            sparse_attention_config=sparse_attn_cfg,
        )

        self.support_fused_qkv = self.attn.support_fused_qkv()

        if not config.skip_create_weights_in_init:
            self.create_weights()

    def create_weights(self):
        """Materialize o_proj weights and record whether it uses quant scales.

        ``self.attn`` owns no weights, but its quantization-dependent state is
        refreshed from ``self.quant_config`` here because that config may be
        modified after ``__init__``.
        """
        self.attn.update_quant_config(self.quant_config)

        out_proj = self.o_proj
        out_proj.create_weights()
        # Truthy when o_proj uses any of these quantization schemes; read by
        # _use_quantize_output().
        self.has_quant_scale = (out_proj.has_fp8_qdq or out_proj.has_nvfp4
                                or out_proj.has_fp8_block_scales
                                or out_proj.has_fp8_rowwise
                                or out_proj.has_w4a8_nvfp4_fp8)

    def split_qkv(self, q, k=None, v=None):
        """Split a fused QKV tensor into ``(q, k, v)`` along the last dim.

        The widths are ``(q_size, kv_size, kv_size)``. If either *k* or *v* is
        given, the inputs are returned as-is.
        """
        if k is not None or v is not None:
            return q, k, v
        q_part, k_part, v_part = q.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1)
        return q_part, k_part, v_part

    def convert_qkv(self, q, k, v):
        """Adapt q/k/v to the layout the attention backend accepts.

        * Fused input (``k`` and ``v`` both None) with a backend that does not
          support fused QKV: split it into separate tensors.
        * Separate q/k/v with a backend that supports fused QKV: concatenate
          them along the last dim and return ``(qkv, None, None)``.
        * Anything else is returned unchanged.
        """
        is_fused_input = k is None and v is None
        is_separate_input = k is not None and v is not None

        if is_fused_input and not self.support_fused_qkv:
            return self.split_qkv(q)
        if is_separate_input and self.support_fused_qkv:
            return torch.concat([q, k, v], dim=-1), None, None
        return q, k, v

    def _use_quantize_output(self):
        """Decide whether the attention kernel should emit quantized output.

        Returns a falsy value when:

        * the attention op produces NVFP4 but o_proj cannot consume NVFP4;
        * the layer has no quantization other than (possibly) KV cache;
        * o_proj has an AWQ ``pre_quant_scale``, so the output stays
          unquantized for that scale to be applied;
        * an attention output gate is used, or o_proj carries no quant scales.
        """
        # o_proj can't consume NVFP4, so there's no point producing it.
        attn_emits_nvfp4 = hasattr(self.attn,
                                   'has_nvfp4') and self.attn.has_nvfp4
        if attn_emits_nvfp4 and not self.o_proj.has_nvfp4:
            return False

        # Nothing is quantized (KV cache aside): keep the output as-is.
        quant_config = self.quant_config
        if quant_config is not None and not quant_config.layer_quant_mode.has_any_quant(
                exclude_kv_cache=True):
            return False

        has_awq_pre_quant_scale = getattr(self.o_proj, 'pre_quant_scale',
                                          None) is not None

        return self.has_quant_scale and not self.attn_output_gate and not has_awq_pre_quant_scale

    def create_output(self, q: torch.Tensor, attn_metadata: AttentionMetadata,
                      mask_type: str):
        """Ask the backend attention op to create output tensors for *q*.

        Quantized output is requested according to ``_use_quantize_output()``.
        The batch is always treated as mixed (``is_gen_only=False``).
        """
        quantize_output = self._use_quantize_output()
        # Attention is treated as mixed request by default.
        return self.attn.create_output(q,
                                       is_quantize_output=quantize_output,
                                       metadata=attn_metadata,
                                       attention_mask=mask_type,
                                       is_gen_only=False)

    def _helix_post_process(self, partial_o: torch.Tensor,
                            softmax_stats: torch.Tensor) -> torch.Tensor:
        """Combine this rank's partial attention output with the other CP
        ranks via the module-level Helix post-process, using ``head_dim`` as
        the value dimension (no auxiliary-stream overlap)."""
        return _helix_post_process(partial_o,
                                   softmax_stats,
                                   mapping=self.mapping,
                                   num_heads_tp_cp=self.num_heads_tp_cp,
                                   value_dim=self.head_dim)

    def _attn_impl(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        attention_mask: AttentionMask,
        mrope_rotary_cos_sin: Optional[torch.Tensor],
        mrope_position_deltas: Optional[torch.Tensor],
        attention_window_size: Optional[int],
        attention_mask_data: Optional[torch.Tensor],
        output: Optional[torch.Tensor] = None,
        output_sf: Optional[torch.Tensor] = None,
        attention_sinks: Optional[torch.Tensor] = None,
        has_lora: bool = False,
    ):
        num_tokens = attn_metadata.num_tokens

        q = q[:num_tokens, :]
        if k is not None:
            k = k[:num_tokens, :]
        if v is not None:
            v = v[:num_tokens, :]

        mrope_config = None
        if mrope_rotary_cos_sin is not None or mrope_position_deltas is not None:
            mrope_config = dict()
            if mrope_rotary_cos_sin is not None:
                mrope_config["mrope_rotary_cos_sin"] = mrope_rotary_cos_sin
            if mrope_position_deltas is not None:
                mrope_config["mrope_position_deltas"] = mrope_position_deltas

        # Helix CP generation path: get partial outputs with softmax stats,
        # then exchange and combine across CP ranks.
        # NOTE: The helix post-process combine step works on unquantized
        # (BF16/FP16) partial outputs and softmax stats from each rank.
        # We intentionally skip passing out_scale/out_scale_sf to FMHA here
        # so it produces BF16 output. After combining, the downstream o_proj
        # linear layer handles quantization (FP8/NVFP4) in its apply() method.
        if self.mapping.has_cp_helix() and attn_metadata.num_contexts == 0:
            assert output is None, (
                "Helix produces BF16 partial outputs which may not match a pre-allocated FP8/NVFP4 buffer for torch.compile inplace output."
            )
            softmax_stats = torch.empty((num_tokens, self.num_heads, 2),
                                        device=q.device,
                                        dtype=torch.float32)
            attn_output = self.attn.forward(
                q,
                k,
                v,
                attn_metadata,
                attention_mask=attention_mask,
                mrope_config=mrope_config,
                attention_window_size=attention_window_size,
                attention_mask_data=attention_mask_data,
                softmax_stats_tensor=softmax_stats,
                attention_sinks=attention_sinks)
            if isinstance(attn_output, tuple):
                attn_output = attn_output[0]
            attn_output = self._helix_post_process(attn_output, softmax_stats)
            return attn_output, None

        out_scale = None
        out_scale_sf = None
        # Don't set out_scale if o_proj has pre_quant_scale - this prevents FP8/FP4 output
        # and keeps attention output in BF16 for better precision when applying pre_quant_scale
        # Also don't set out_scale if LoRA is active - LoRA grouped_gemm doesn't support FP8
        if self._use_quantize_output() and not has_lora:
            out_scale = self.o_proj.inv_input_scale
            out_scale_sf = self.o_proj.input_scale

        kv_scales_sf = None
        kv_scales_sf_inv = None
        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp4_kv_cache(
        ):
            kv_scales_sf = self.qkv_proj.kv_scales
            kv_scales_sf_inv = self.qkv_proj.inv_kv_scales

        attn_output = self.attn.forward(
            q,
            k,
            v,
            attn_metadata,
            out_scale=out_scale,
            out_scale_sf=out_scale_sf,
            kv_scales_sf=kv_scales_sf,
            kv_scales_sf_inv=kv_scales_sf_inv,
            attention_mask=attention_mask,
            mrope_config=mrope_config,
            attention_window_size=attention_window_size,
            attention_mask_data=attention_mask_data,
            output=output[:num_tokens, :] if output is not None else None,
            output_sf=output_sf,
            attention_sinks=attention_sinks)
        if isinstance(attn_output, tuple):
            assert len(
                attn_output
            ) == 2, "attn_output should be a tuple of (output, output_sf)"
            return attn_output[0], attn_output[1]
        return attn_output, None

    def forward_impl(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        attention_mask: AttentionMask,
        attention_window_size: Optional[int],
        attention_mask_data: Optional[torch.Tensor],
        mrope_config: Optional[dict],
        attention_sinks: Optional[torch.Tensor] = None,
        has_lora: bool = False,
    ):
        """Run the attention kernel on already-projected (and roped) q/k/v.

        When the layer is registered in the model config, the backend is
        TRTLLM or FLASHINFER, and we are under torch compile, the call is
        routed through the ``attn_custom_op_inplace`` custom op with output
        buffers preallocated by ``create_attn_outputs``. Otherwise
        ``self._attn_impl`` is called directly.

        Args:
            q, k, v: Query/key/value tensors (k/v may be None when fused into q).
            attn_metadata: Attention metadata for the current batch.
            attention_mask: Attention mask type.
            attention_window_size: Optional sliding-window size.
            attention_mask_data: Optional explicit mask tensor.
            mrope_config: Optional dict that may carry
                ``"mrope_rotary_cos_sin"`` and/or ``"mrope_position_deltas"``.
            attention_sinks: Optional attention-sink tensor.
            has_lora: Whether LoRA is active (forwarded to ``_attn_impl``).

        Returns:
            The attention output, wrapped in ``Fp4QuantizedTensor`` when a
            scale-factor tensor was produced alongside it.
        """
        mrope = mrope_config if mrope_config is not None else {}
        rotary_cos_sin = mrope.get("mrope_rotary_cos_sin")
        position_deltas = mrope.get("mrope_position_deltas")

        # Only TRTLLM and FLASHINFER are torch-compile compatible, and the
        # in-place custom op is only needed while torch compiling.
        use_custom_inplace_op = (self.register_to_config
                                 and self.attn_backend in ("TRTLLM",
                                                           "FLASHINFER")
                                 and is_torch_compiling())

        if not use_custom_inplace_op:
            output, output_sf = self._attn_impl(q,
                                                k,
                                                v,
                                                attn_metadata,
                                                attention_mask,
                                                rotary_cos_sin,
                                                position_deltas,
                                                attention_window_size,
                                                attention_mask_data,
                                                attention_sinks=attention_sinks,
                                                has_lora=has_lora)
        else:
            # NOTE(review): has_lora is not forwarded on this path; confirm
            # the custom op picks the right output quantization with LoRA.
            buffers = create_attn_outputs(q, attention_mask, self.layer_idx_str)
            assert len(buffers) in (1, 2)
            output = buffers[0]
            output_sf = buffers[1] if len(buffers) == 2 else None
            attn_custom_op_inplace(
                q,
                k,
                v,
                attention_mask,
                rotary_cos_sin,
                position_deltas,
                attention_window_size,
                attention_mask_data,
                attention_sinks,
                self.layer_idx_str,
                output,
                output_sf,
            )

        return output if output_sf is None else Fp4QuantizedTensor(
            output, output_sf)

    def forward(
        self,
        position_ids: Optional[torch.IntTensor],
        hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
        attn_metadata: AttentionMetadata,
        attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
        mrope_config: Optional[dict] = None,
        all_reduce_params: Optional[AllReduceParams] = None,
        lora_params: Optional[dict] = None,
        attention_window_size: Optional[int] = None,
        attention_mask_data: Optional[torch.Tensor] = None,
        attention_sinks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass for the Attention module.

        The pass runs these steps in order:

        1. Helix-CP all-gather of the input.
        2. qkv projection, plus LoRA deltas when ``lora_params`` is non-empty.
        3. Optional split of the output gate.
        4. RoPE, unless it is fused into the attention op.
        5. Attention.
        6. Optional sigmoid gating.
        7. Output projection.

        Args:
            position_ids (Optional[torch.IntTensor]): The position IDs.
            hidden_states (Union[torch.Tensor, Fp4QuantizedTensor]): The hidden states.
            attn_metadata (AttentionMetadata): The attention metadata.
            attention_mask (AttentionMask): The attention mask type.
            mrope_config (Optional[dict]): The MROPE configuration.
            all_reduce_params (Optional[AllReduceParams]): The all reduce parameters.
            lora_params (Optional[dict]): The LoRA parameters; LoRA is treated
                as active whenever this is truthy.
            attention_window_size (Optional[int]): The attention window size.
            attention_mask_data (Optional[torch.Tensor]): The attention mask data.
            attention_sinks (Optional[torch.Tensor]): Attention-sink tensor
                forwarded to the attention op. Only supported with the
                TRTLLM backend.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            torch.Tensor: The output tensor.
        """
        has_lora = bool(lora_params)

        hidden_states = _helix_cp_allgather_input(hidden_states, attn_metadata,
                                                  self.mapping, self.layer_idx)

        qkv = self.qkv_proj(hidden_states)

        if has_lora:
            # Add the deltas of the split (q/k/v) and fused qkv adapters, in
            # that order. Either one may return None for this layer.
            for lora_module in (self.splitted_qkv_lora, self.fused_qkv_lora):
                qkv_lora = lora_module(hidden_states, lora_params,
                                       self.layer_idx)
                if qkv_lora is not None:
                    qkv = qkv + qkv_lora

        if self.attn_output_gate:
            q_gate, k, v = qkv.split(
                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
            orig_shape = q_gate.shape[:-1]
            # Each head's slice of q_gate holds [q | gate]. View it per head,
            # split every head's slice in half along the last dim, then
            # flatten each half back to (..., num_heads * head_width).
            q, gate = [
                t.reshape(*orig_shape, -1) for t in torch.chunk(
                    q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
            ]
        else:
            q, k, v = qkv, None, None

        # For dynamic tree spec decoding with Python RoPE, adjust position_ids
        # to use tree offsets (same as C++ kernel: past_seq_len + offset).
        if (not self.rope_fusion
                and getattr(attn_metadata, 'is_spec_dec_dynamic_tree', False)
                and getattr(attn_metadata, 'use_spec_decoding', False)
                and getattr(attn_metadata, 'spec_decoding_position_offsets',
                            None) is not None
                and attn_metadata.spec_decoding_position_offsets.dim() ==
                1  # 1D layout => dynamic tree
                and position_ids is not None):
            position_ids = self._adjust_position_ids_for_spec_dec(
                position_ids, attn_metadata)

        q, k, v = self.apply_rope(q, k, v, position_ids)
        q, k, v = self.convert_qkv(q, k, v)

        if attention_sinks is not None:
            assert self.attn_backend == "TRTLLM", "Attention sinks are only supported for TRTLLM backend."

        attn_output = self.forward_impl(q,
                                        k,
                                        v,
                                        attn_metadata,
                                        attention_mask,
                                        attention_window_size,
                                        attention_mask_data,
                                        mrope_config=mrope_config,
                                        attention_sinks=attention_sinks,
                                        has_lora=has_lora)

        if self.attn_output_gate:
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate

        attn_output = _helix_cp_output_projection(self.o_proj, attn_output,
                                                  attn_metadata,
                                                  all_reduce_params,
                                                  self.mapping, self.mapping_o,
                                                  self.layer_idx, lora_params)
        return attn_output

    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
                   v: Optional[torch.Tensor],
                   position_ids: Optional[torch.Tensor]):
        """
        Apply RoPE to the query and key.

        Depending on the implementation, q, k, v could be either fused
        (q, k, v = concat(q, k, v), None, None) or unfused (none of q, k, v is
        None). Before self.attn.forward, convert_qkv will be called to make
        sure that the format of (q, k, v) satisfies the requirement of
        self.attn. This method could be overridden in the subclass, in which
        extra functionalities such as q_norm/k_norm could be added.

        Args:
            q (torch.Tensor): The query tensor (or fused qkv).
            k (Optional[torch.Tensor]): The key tensor.
            v (Optional[torch.Tensor]): The value tensor.
            position_ids (Optional[torch.Tensor]): The position IDs of each
                token for RoPE. When None, RoPE is skipped.

        Returns:
            tuple: A tuple of (q, k, v). The inputs are returned untouched when
            RoPE is fused into the attention op or ``position_ids`` is None.
        """
        # Fused RoPE is applied inside the attention op. Without positions
        # there is nothing to rotate.
        if self.rope_fusion or position_ids is None:
            return q, k, v
        q, k, v = self.split_qkv(q, k, v)
        q, k = self.rotary_emb(position_ids, [q, k])
        return q, k, v

    def _adjust_position_ids_for_spec_dec(self, position_ids, attn_metadata):
        """Replicate C++ kernel's rotary_pos = past_seq_len + offset."""
        num_contexts = attn_metadata.num_contexts
        num_gens = attn_metadata.num_seqs - num_contexts
        if num_gens <= 0:
            return position_ids
        gen_len = int(attn_metadata.seq_lens[num_contexts])
        base_pos = attn_metadata.kv_lens_cuda[num_contexts:num_contexts +
                                              num_gens] - gen_len
        offsets = attn_metadata.spec_decoding_position_offsets[:num_gens *
                                                               gen_len].view(
                                                                   num_gens,
                                                                   gen_len)
        adjusted = (base_pos.unsqueeze(1) + offsets).reshape(-1)
        start = attn_metadata.num_ctx_tokens
        end = start + num_gens * gen_len
        position_ids[0, start:end] = adjusted
        return position_ids

    def apply_qk_norm(self, q, k):
        """Hook for per-head q/k normalization; subclasses must override.

        Raises:
            NotImplementedError: always, naming the concrete class.
        """
        owner = self.__class__.__name__
        message = (f"QK norm is not implemented for {owner}. "
                   "Please override the `apply_qk_norm` method in the subclass.")
        raise NotImplementedError(message)


@torch.library.custom_op("trtllm::mla_custom_op_inplace",
                         mutates_args=("output", ))
def mla_custom_op_inplace(
    hidden_states: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    layer_idx: str,
    output: torch.Tensor,
    latent_cache_gen: Optional[torch.Tensor],
) -> None:
    """MLA forward exposed as an in-place torch custom op.

    Resolves the attention metadata and MLA layer registered for
    ``layer_idx`` and calls ``forward_impl``, passing ``output`` as the
    destination buffer. ``output`` is declared mutated so torch compile
    treats the op as in-place.
    """
    attn_metadata, layer = extract_extra_attrs(layer_idx, "mla")
    layer.forward_impl(
        position_ids,
        hidden_states,
        attn_metadata,
        output=output,
        latent_cache_gen=latent_cache_gen,
    )


@torch.library.custom_op("trtllm::mla_dsa_proj", mutates_args=())
def mla_dsa_proj(
    hidden_states: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    layer_idx: str,
) -> List[torch.Tensor]:
    """Token-wise projections for DSA MLA (CUDA-graph-capturable).

    Runs kv_a_proj, layernorms, q_b_proj, and conditionally
    indexer.pre_indexer_proj (FP8 quantize, weight scaling). It does NOT
    update the indexer k cache. That happens in Op 2 (mla_dsa_attn_inplace),
    because the scatter kernel accesses batch-specific metadata.

    Returns:
        [q, compressed_kv, k_pe, latent_cache] when the short-MHA path handles
        all tokens. Otherwise [q, compressed_kv, k_pe, latent_cache, q_fp8,
        k_fp8, k_scale, weights] when the indexer runs.

        Under torch compile, _should_use_short_mha returns False, so the
        result is always length 8. This keeps control flow straight-line for
        CUDA graph capture.
    """
    layer_metadata, layer = extract_extra_attrs(layer_idx, "mla")
    projections = layer.forward_dsa_proj(position_ids, hidden_states,
                                         layer_metadata)
    return projections


@mla_dsa_proj.register_fake
def _mla_dsa_proj_fake(
    hidden_states: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    layer_idx: str,
) -> List[torch.Tensor]:
    """Fake (shape-only) implementation of ``trtllm::mla_dsa_proj``.

    Allocates uninitialized tensors whose shapes/dtypes are meant to mirror
    the real op's outputs, using the dimensions of the MLA layer registered
    for ``layer_idx``. Under torch compile ``_should_use_short_mha`` is False,
    so the real op always returns the 8-tensor form and this fake does too:
    [q, compressed_kv, k_pe, latent_cache, q_fp8, k_fp8, k_scale, weights].
    """
    # Only the layer's dimensions are needed here; the metadata is unused.
    _, mla_layer = extract_extra_attrs(layer_idx, "mla")
    num_tokens = hidden_states.shape[0]
    indexer = mla_layer.mqa.indexer

    q = hidden_states.new_empty(
        [num_tokens, mla_layer.num_heads_tp * mla_layer.qk_head_dim])
    compressed_kv = hidden_states.new_empty(
        [num_tokens, mla_layer.kv_lora_rank])
    k_pe = hidden_states.new_empty([num_tokens, mla_layer.qk_rope_head_dim])
    latent_cache = hidden_states.new_empty(
        [num_tokens, mla_layer.kv_lora_rank + mla_layer.qk_rope_head_dim])

    # Indexer intermediates: q_fp8, k_fp8, k_scale, weights
    q_fp8 = hidden_states.new_empty(
        [num_tokens, indexer.n_heads, indexer.head_dim],
        dtype=torch.float8_e4m3fn)
    k_fp8 = hidden_states.new_empty([num_tokens, indexer.head_dim],
                                    dtype=torch.float8_e4m3fn)
    k_scale = hidden_states.new_empty([num_tokens, 1], dtype=torch.float32)
    weights = hidden_states.new_empty([num_tokens, indexer.n_heads],
                                      dtype=torch.float32)
    return [
        q, compressed_kv, k_pe, latent_cache, q_fp8, k_fp8, k_scale, weights
    ]


@torch.library.custom_op("trtllm::mla_dsa_attn_inplace",
                         mutates_args=("output", ))
def mla_dsa_attn_inplace(
    q: torch.Tensor,
    compressed_kv: torch.Tensor,
    k_pe: torch.Tensor,
    latent_cache: torch.Tensor,
    indexer_intermediates: List[torch.Tensor],
    position_ids: Optional[torch.Tensor],
    layer_idx: str,
    output: torch.Tensor,
) -> None:
    """Batch-structure-dependent attention dispatch for DSA MLA.

    ``indexer_intermediates`` takes one of two forms:

    * [q_fp8, k_fp8, k_scale, weights] when the indexer ran in Op 1
      (mla_dsa_proj).
    * [] when short-MHA handled all tokens.

    The op runs sparse_attn_indexer, then dispatches context/generation
    attention, writing into ``output``. This op is excluded from CUDA graph
    capture.
    """
    layer_metadata, layer = extract_extra_attrs(layer_idx, "mla")
    layer.forward_dsa_attn(
        q,
        compressed_kv,
        k_pe,
        latent_cache,
        indexer_intermediates,
        position_ids,
        layer_metadata,
        output,
    )


def fp8_block_scaling_bmm_out(
    mat1: torch.Tensor,
    mat2_fp8: torch.Tensor,
    mat2_scale: torch.Tensor,
    out: torch.Tensor,
    mat2_dequant: Optional[torch.Tensor] = None,
    use_cute_dsl_blockscaling_bmm: bool = False,
) -> torch.Tensor:
    """Batched matmul of ``mat1`` against FP8 block-scaled ``mat2``, into ``out``.

    The implementation is chosen by the GPU SM version:

    * SM89 / SM90: quantize ``mat1`` with
      ``trtllm.fp8_batched_quantize_1x128_permute102``, then run
      ``trtllm.fp8_block_scaling_bmm_out``.
    * SM120: quantize ``mat1`` with
      ``fp8_utils.per_token_quant_and_transform(need_permute102=True)``, then
      run the same block-scaling BMM.
    * SM100 family: either the CuTe DSL FP8 BMM (when
      ``use_cute_dsl_blockscaling_bmm``) or a plain ``torch.bmm`` of
      ``mat1.transpose(0, 1)`` with the pre-dequantized
      ``mat2_dequant.transpose(1, 2)``.

    Args:
        mat1: Left operand (unquantized). It is quantized here on the FP8
            kernel paths.
        mat2_fp8: FP8 right operand.
        mat2_scale: Block scales for ``mat2_fp8``.
        out: Destination tensor, written in place.
        mat2_dequant: Dequantized right operand. Required on the SM100
            ``torch.bmm`` path.
        use_cute_dsl_blockscaling_bmm: On SM100, use the CuTe DSL kernel.

    Returns:
        ``out`` (the same tensor object, now holding the result).

    Raises:
        ValueError: on the SM100 ``torch.bmm`` path when ``mat2_dequant`` is
            None.
        NotImplementedError: for unsupported SM versions.
    """
    sm_version = get_sm_version()
    if sm_version in (89, 90, 120):
        if sm_version == 120:
            mat1_fp8, mat1_scale = fp8_utils.per_token_quant_and_transform(
                mat1, need_permute102=True)
        else:
            mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                mat1)
        # The kernel writes into a freshly allocated buffer that is then
        # copied into ``out``. NOTE(review): presumably this is because ``out``
        # may be a non-contiguous view; confirm before writing directly.
        output = out.new_empty(out.shape, dtype=out.dtype, device=out.device)
        torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
                                                   mat1_scale, mat2_scale,
                                                   output)
        out.copy_(output)
    elif is_sm_100f(sm_version):
        if use_cute_dsl_blockscaling_bmm:
            mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                mat1)
            torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8,
                                                        mat1_scale, mat2_scale,
                                                        out)
        else:
            if mat2_dequant is None:
                raise ValueError(
                    "mat2_dequant is required for the torch.bmm path on "
                    f"SM{sm_version} when use_cute_dsl_blockscaling_bmm is "
                    "False")
            torch.bmm(mat1.transpose(0, 1),
                      mat2_dequant.transpose(1, 2),
                      out=out)
    else:
        raise NotImplementedError(f"SM{sm_version} is not supported")
    return out


class MLA(nn.Module):

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        predicted_tokens_per_seq: int,
        max_position_embeddings: int,
        bias: bool,
        aux_stream: Optional[torch.cuda.Stream] = None,
        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        layer_idx: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        dense_bias: Optional[bool] = None,
        config: Optional[ModelConfig] = None,
        mapping_with_cp: Optional[Mapping] = None,
        reduce_output: bool = True,
    ):
        """
        Initialize the MLA (multi-head latent attention) module.

        Args:
            hidden_size (int): The size of the hidden dimension.
            num_attention_heads (int): The number of attention heads, before
                TP/CP sharding.
            num_key_value_heads (int): The number of key value heads.
            qk_nope_head_dim (int): The dimension of the query and key without Rope.
            qk_rope_head_dim (int): The dimension of the Rope of query and key.
            v_head_dim (int): The dimension of the value.
            q_lora_rank (Optional[int]): The dimension of the compressed query.
                None selects the "lite" variant. In that variant the queries
                are projected by ``q_proj`` straight from the hidden states,
                and ``self.q_lora_rank`` is set to ``hidden_size``.
            kv_lora_rank (int): The dimension of the compressed key and value.
            predicted_tokens_per_seq (int): The number of predicted tokens per sequence.
            max_position_embeddings (int): The maximum position embeddings.
            bias (bool): Whether to use bias in the linear layers.
            aux_stream (Optional[torch.cuda.Stream]): The auxiliary CUDA stream for running operations in two parallel streams.
            pos_embd_params (PositionalEmbeddingParams): The positional embedding parameters. Required; an assertion fails if it is None.
            layer_idx (int): The layer index.
            dtype (torch.dtype): The data type.
            dense_bias (bool): Whether to use bias in the output projection layer. Falls back to ``bias`` when None.
            config (ModelConfig): The model configuration. When provided, a
                weak reference to this module is stored in
                ``config.extra_attrs["mla_layers"][str(layer_idx)]``. When
                None, a default ``ModelConfig()`` is used.
            mapping_with_cp (Optional[Mapping]): If given, used instead of
                ``config.mapping`` (a one-time warning is logged).
            reduce_output (bool): Forwarded as ``reduce_output`` to the
                row-parallel ``o_proj``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_idx_str = str(layer_idx)
        self.dtype = dtype

        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.predicted_tokens_per_seq = predicted_tokens_per_seq
        self.max_position_embeddings = max_position_embeddings
        self.pos_embd_params = pos_embd_params
        self.dense_bias = dense_bias
        if dense_bias is None:
            self.dense_bias = bias

        # "Lite" variant: no query compression, so the query projection below
        # (q_proj) takes hidden_size-wide inputs directly.
        if self.q_lora_rank is None:
            self.q_lora_rank = hidden_size
            self.is_lite = True
        else:
            self.is_lite = False

        assert pos_embd_params is not None, "pos_embd_params must be provided in MLA"

        # Register a weak reference to this layer under its layer-index string.
        # The trtllm::mla_* custom ops only receive ``layer_idx`` and look the
        # layer up from it (presumably via this registry in
        # extract_extra_attrs). The weakref keeps the registry from holding
        # the module alive.
        self.register_to_config = False
        if config is not None:
            if "mla_layers" not in config.extra_attrs:
                config.extra_attrs["mla_layers"] = {}
            config.extra_attrs["mla_layers"][self.layer_idx_str] = weakref.ref(
                self)
            self.register_to_config = True

        # Currently only DSA sparse attention is supported.
        if config is not None and config.sparse_attention_config is not None and config.sparse_attention_config.algorithm == "dsa":
            self.is_dsa = True
        else:
            self.is_dsa = False

        # tensor parallel
        config = config or ModelConfig()
        if mapping_with_cp is not None:
            logger.warning_once(
                "[MLA::__init__] Overriding mapping with CP detected.",
                key="mla_init_mapping_with_cp")
            self.mapping = mapping_with_cp
        else:
            self.mapping = config.mapping
        tp_size = self.mapping.tp_size
        pp_size = self.mapping.pp_size
        cp_size = self.mapping.cp_size
        dp_size = 1
        # With attention DP, the TP ranks act as data-parallel replicas and
        # each rank keeps all heads (tp_size becomes 1 for attention weights).
        if self.mapping.enable_attention_dp:
            dp_size = tp_size
            tp_size = 1
        if self.mapping.has_cp_ulysses():
            raise NotImplementedError("MLA doesn't support CP Ulysses yet")
        if self.mapping.cp_size > 1:
            assert self.mapping.has_cp_helix(
            ), f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."

        # Mapping for the column-parallel q_b_proj / q_proj / kv_b_proj. DP
        # ranks are folded into the PP dimension, so these weights are split
        # over tp_size only.
        mapping = Mapping(
            world_size=pp_size * dp_size * tp_size * cp_size,
            tp_size=tp_size,
            pp_size=pp_size * dp_size,
            cp_size=cp_size,
            cp_config=self.mapping.cp_config,
            rank=self.mapping.rank,
            gpus_per_node=self.mapping.gpus_per_node,
            enable_attention_dp=self.mapping.enable_attention_dp,
        )

        # Heads are sharded over TP for the projections/attention ops, and
        # further over CP for the per-rank v_b_proj shape below.
        assert self.num_heads % (tp_size * cp_size) == 0
        self.num_heads_tp = self.num_heads // tp_size
        self.num_heads_tp_cp = self.num_heads_tp // cp_size
        self.num_key_value_heads_tp = (self.num_key_value_heads + tp_size -
                                       1) // tp_size

        rms_norm_eps = getattr(config.pretrained_config, "rms_norm_eps", 1e-6)
        quant_config = config.get_quant_config()
        self.quant_config = quant_config

        self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
        self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm

        if not self.is_lite:
            # A single replicated projection produces
            # [q_lora | kv_lora | k_rope] from the hidden states.
            self.kv_a_proj_with_mqa = Linear(
                hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=bias,
                dtype=dtype,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                use_custom_cublas_mm=True,
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

            self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
                                         eps=rms_norm_eps,
                                         dtype=dtype)

            self.q_b_proj = Linear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=bias,
                dtype=dtype,
                mapping=mapping,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                allreduce_strategy=config.allreduce_strategy,
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
        else:
            # Lite: kv_a_proj_with_mqa has no q_lora slice; queries come from
            # q_proj, which is aliased as q_b_proj for shared code paths.
            self.kv_a_proj_with_mqa = Linear(
                hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=bias,
                dtype=dtype,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                use_custom_cublas_mm=True,
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

            self.q_proj = Linear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=bias,
                dtype=dtype,
                mapping=mapping,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                allreduce_strategy=config.allreduce_strategy,
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
            self.q_b_proj = self.q_proj

        self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
                                      dtype=dtype,
                                      eps=rms_norm_eps)

        self.kv_b_proj = Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            quant_config=quant_config,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
        # This parameter will view into self.kv_b_proj.weight after loading weights.
        # For dummy weight initialization, this parameter is initialized with empty tensor.
        # Used in forward_absorption only
        self.v_b_proj = nn.Parameter(
            torch.empty(
                (self.num_heads_tp_cp, self.v_head_dim, self.kv_lora_rank),
                dtype=dtype,
            ),
            requires_grad=False,
        )

        # o_proj is row-parallel over all tp_size * cp_size ranks: CP is
        # folded into the TP dimension here (cp_size=1).
        mapping_o = Mapping(
            world_size=pp_size * dp_size * tp_size * cp_size,
            tp_size=tp_size * cp_size,
            pp_size=pp_size * dp_size,
            cp_size=1,
            rank=self.mapping.rank,
            gpus_per_node=self.mapping.gpus_per_node,
            enable_attention_dp=self.mapping.enable_attention_dp,
        )
        self.mapping_o = mapping_o
        self.o_proj = Linear(
            self.num_key_value_heads * self.v_head_dim,
            self.hidden_size,
            bias=self.dense_bias,
            dtype=dtype,
            mapping=mapping_o,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=quant_config,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            reduce_output=reduce_output,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

        # YaRN attention scaling: mscale = 0.1 * mscale_all_dim * ln(scale) + 1
        # for scale > 1 (else 1.0). q_scaling = 1 / mscale^2 is passed to the
        # attention ops and folded into softmax_scale below.
        def yarn_get_mscale(scale=1, mscale=1):
            if scale <= 1:
                return 1.0
            return 0.1 * mscale * math.log(scale) + 1.0

        mscale_all_dim = pos_embd_params.rope.mscale_all_dim
        scaling_factor = pos_embd_params.rope.scale
        mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
        q_scaling = 1.0 / (mscale * mscale)

        # Latent (absorption-path) attention: one KV head whose width is
        # kv_lora_rank + qk_rope_head_dim, with values of width kv_lora_rank.
        self.mqa = create_attention(
            config.attn_backend,
            self.layer_idx,
            self.num_heads_tp,
            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
            num_kv_heads=1,
            pos_embd_params=pos_embd_params,
            quant_config=quant_config,
            q_scaling=q_scaling,
            is_mla_enable=True,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.kv_lora_rank,
            hidden_size=self.hidden_size,
            predicted_tokens_per_seq=self.predicted_tokens_per_seq,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            sparse_attention_config=config.sparse_attention_config,
            dtype=dtype,
            aux_stream=aux_stream,
        )

        self.softmax_scale = 1.0 / (math.sqrt(self.qk_head_dim) * q_scaling)

        self.aux_stream = aux_stream
        # NOTE(review): presumably used to synchronize work overlapped on
        # aux_stream (e.g. the q/kv layernorms); confirm in the forward paths.
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

        # When the attention backend cannot fuse RoPE, apply it in Python via
        # RotaryEmbedding over the qk_rope_head_dim slice.
        self.rope_fusion = self.mqa.support_fused_rope()
        self.rotary_emb = None
        self.apply_rotary_emb = not self.rope_fusion
        if self.apply_rotary_emb:
            self.rotary_emb = RotaryEmbedding(
                pos_embd_params.rope,
                head_dim=self.qk_rope_head_dim,
                is_neox=pos_embd_params.is_neox,
            )

        # Short-sequence MHA optimization for DSA models:
        # For short prefill sequences, use MHA (kv_b_proj expansion + standard
        # attention) instead of the absorption path, which has overhead from
        # extra BMMs and larger head_dim (kv_lora_rank + qk_rope_head_dim).
        # Only active when rope_fusion is True (DSA with TrtllmAttention).
        # A threshold <= 0 (the default '0') disables it.
        _threshold_str = os.environ.get('TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD',
                                        '0')
        try:
            self.short_seq_mha_threshold = int(_threshold_str)
        except ValueError as err:
            raise ValueError(
                f"TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD must be an integer, "
                f"got '{_threshold_str}'") from err

        # MHA attention backend: used by non-DSA (standard MLA) and optionally
        # by DSA for the short-seq path (dense attention, no sparse config).
        _short_seq_mha = (self.is_dsa and self.short_seq_mha_threshold > 0
                          and not self.apply_rotary_emb)
        if not self.is_dsa or _short_seq_mha:
            self.mha = create_attention(
                config.attn_backend,
                self.layer_idx,
                self.num_heads_tp,
                head_dim=self.qk_head_dim,
                num_kv_heads=self.num_key_value_heads_tp,
                pos_embd_params=pos_embd_params,
                quant_config=quant_config,
                q_scaling=q_scaling,
                is_mla_enable=True,
                q_lora_rank=self.q_lora_rank,
                kv_lora_rank=self.kv_lora_rank,
                qk_nope_head_dim=self.qk_nope_head_dim,
                qk_rope_head_dim=self.qk_rope_head_dim,
                v_head_dim=self.v_head_dim,
                predicted_tokens_per_seq=self.predicted_tokens_per_seq,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                sparse_attention_config=(None if _short_seq_mha else
                                         config.sparse_attention_config),
            )
        else:
            self.mha = None

        # Optional Llama-4-style attention scaling. Its parameters fall back
        # to 8192 / 0.1 when the pretrained sub-config does not define them.
        self.llama_4_scaling = False
        if hasattr(config.pretrained_config, 'llama_4_scaling'):
            self.llama_4_scaling = True
            self.floor_scale = getattr(config.pretrained_config.llama_4_scaling,
                                       'original_max_position_embeddings', 8192)
            self.attn_scale = getattr(config.pretrained_config.llama_4_scaling,
                                      'beta', 0.1)

        if not config.skip_create_weights_in_init:
            self.create_weights()

    def create_weights(self):
        """Create the quantization-dependent MLA parameters.

        The attention backends (``self.mha`` / ``self.mqa``) own no weights but
        hold state derived from ``self.quant_config``, which may be modified
        after ``__init__``, so that state is refreshed first. ``self.mha`` is
        only present for non-DSA models or for DSA with the short-seq MHA path.

        ``k_b_proj_trans`` is then allocated with a dtype that follows
        ``self.kv_b_proj``: ``torch.float8_e4m3fn`` when kv_b_proj uses FP8
        block scales, otherwise ``self.dtype``. With FP8 block scales, float32
        scale tensors (one entry per 128x128 block) are allocated too. On
        SM100-family GPUs without the CuTe DSL block-scaling BMM, dequantized
        ``self.dtype`` copies of both projections are also allocated. All
        tensors come from ``torch.empty`` (uninitialized) and are
        non-trainable.
        """
        if self.mha is not None:
            self.mha.update_quant_config(self.quant_config)
        self.mqa.update_quant_config(self.quant_config)

        # MLA may run in FP8 for context/generation, but the output stays BF16.
        self.out_scale = None

        def _frozen(shape, dtype):
            # Uninitialized, non-trainable parameter (filled by weight loading).
            return nn.Parameter(torch.empty(shape, dtype=dtype),
                                requires_grad=False)

        # Re-derived here because kv_b_proj's quant_config can change after
        # __init__, and k_b_proj_trans must match its dtype.
        kv_b_quant_config = self.kv_b_proj.quant_config
        has_fp8_block_scales = (
            kv_b_quant_config
            and kv_b_quant_config.quant_mode.has_fp8_block_scales())

        self.k_b_proj_trans = _frozen(
            (self.num_heads_tp, self.kv_lora_rank, self.qk_nope_head_dim),
            torch.float8_e4m3fn if has_fp8_block_scales else self.dtype,
        )

        self.k_b_proj_trans_dequant = None
        self.v_b_proj_dequant = None
        if not has_fp8_block_scales:
            self.k_b_proj_trans_scale = None
            self.v_b_proj_scale = None
            return

        block = 128  # FP8 block-scale granularity along each weight dim
        self.k_b_proj_trans_scale = _frozen(
            (self.num_heads_tp, self.kv_lora_rank // block,
             self.qk_nope_head_dim // block),
            torch.float32,
        )
        # This parameter will view into self.kv_b_proj.weight_scale after
        # loading weights; with dummy weight initialization it remains an
        # empty tensor.
        self.v_b_proj_scale = _frozen(
            (self.num_heads_tp_cp, self.v_head_dim // block,
             self.kv_lora_rank // block),
            torch.float32,
        )

        if not is_sm_100f() or self.use_cute_dsl_blockscaling_bmm:
            return
        assert self.dtype == torch.bfloat16
        self.k_b_proj_trans_dequant = _frozen(
            (self.num_heads_tp, self.kv_lora_rank, self.qk_nope_head_dim),
            self.dtype,
        )
        self.v_b_proj_dequant = _frozen(
            (self.num_heads_tp_cp, self.v_head_dim, self.kv_lora_rank),
            self.dtype,
        )

    def apply_rope(
        self,
        q: torch.Tensor,
        k_pe: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate the RoPE channels of ``q`` (in place) and ``k_pe``.

        ``q`` is viewed as [-1, num_heads_tp, qk_head_dim]. For every head, its
        last ``qk_rope_head_dim`` channels are flattened, passed through
        ``self.rotary_emb`` together with ``k_pe``, and written back into ``q``
        through that view, so the caller's ``q`` is modified.

        Args:
            q: Query tensor; must be ``view``-able as
                [-1, num_heads_tp, qk_head_dim].
            k_pe: RoPE key portion, forwarded to ``self.rotary_emb``.
            position_ids: Positions forwarded to ``self.rotary_emb``.

        Returns:
            The ``k_pe`` result produced by ``self.rotary_emb``.
        """
        rope_start = self.qk_nope_head_dim
        per_head = q.view(-1, self.num_heads_tp, self.qk_head_dim)
        q_rope_flat = per_head[..., rope_start:].reshape(
            -1, self.num_heads_tp * self.qk_rope_head_dim)
        rotated_q_rope, rotated_k_pe = self.rotary_emb(position_ids,
                                                       [q_rope_flat, k_pe])
        # Write the rotated channels back through the view into q.
        per_head[..., rope_start:] = rotated_q_rope.view(
            -1, self.num_heads_tp, self.qk_rope_head_dim)
        return rotated_k_pe

    def _attn_forward_gen(self, attn_backend: AttentionBackend, q: torch.Tensor,
                          k: torch.Tensor, v: torch.Tensor,
                          position_ids: Optional[torch.Tensor],
                          attn_metadata: AttentionMetadata, **kwargs):
        """Run generation-phase attention through ``attn_backend``.

        Without Helix context parallelism this is exactly
        ``attn_backend.forward(q, k, v, attn_metadata, **kwargs)``.

        With Helix CP, a float32 buffer of shape [num_tokens, num_heads_tp, 2]
        is passed as ``softmax_stats_tensor``. The returned partial output and
        the stats are then handed to ``_helix_post_process`` (presumably to
        combine partial results across CP ranks; see that function).

        ``position_ids`` is accepted for signature compatibility but unused.
        """
        if not self.mapping.has_cp_helix():
            return attn_backend.forward(q, k, v, attn_metadata, **kwargs)

        # partial_o: [num_tokens, num_heads_tp * kv_lora_rank]
        # softmax_stats: [num_tokens, num_heads_tp, 2]
        stats_shape = (q.shape[0], self.num_heads_tp, 2)
        softmax_stats = torch.empty(stats_shape,
                                    device=q.device,
                                    dtype=torch.float32)
        partial_o = attn_backend.forward(
            q,
            k,
            v,
            attn_metadata,
            softmax_stats_tensor=softmax_stats,
            **kwargs,
        )
        lora_rank = partial_o.shape[-1] // self.num_heads_tp
        assert self.kv_lora_rank == lora_rank
        return _helix_post_process(partial_o, softmax_stats, self.mapping,
                                   self.num_heads_tp_cp, lora_rank,
                                   self.aux_stream, self.ln_events)

    def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
        """Allocate an uninitialized output buffer for this attention layer.

        The buffer has one row per row of ``hidden_states`` and
        ``self.o_proj.in_features`` columns. Its dtype and device match
        ``hidden_states``. ``num_contexts`` is accepted but unused.
        """
        out_shape = [hidden_states.shape[0], self.o_proj.in_features]
        return hidden_states.new_empty(out_shape, dtype=hidden_states.dtype)

    def _attention_scaling(self, q, position_ids):
        """Scale queries by the position-dependent Llama-4 attention factor.

        For a token at position ``p`` the factor is
        ``1 + self.attn_scale * log(floor((p + 1) / self.floor_scale) + 1)``.
        Each row of ``q`` is multiplied by its token's factor, and the result
        is cast back to ``q.dtype``.

        Args:
            q: Queries with one row per token; the row count must match the
                number of elements in ``position_ids``.
            position_ids: Token positions; flattened before use.

        Returns:
            The scaled queries, as a new tensor of ``q.dtype``.
        """
        flat_positions = position_ids.view(-1)
        bucket = torch.floor((flat_positions + 1.0) / self.floor_scale)
        factor = torch.log(bucket + 1.0) * self.attn_scale + 1.0
        return (q * factor.unsqueeze(-1)).to(q.dtype)

    def forward_impl(self,
                     position_ids: Optional[torch.Tensor],
                     hidden_states: torch.Tensor,
                     attn_metadata: AttentionMetadata,
                     output: torch.Tensor,
                     latent_cache_gen: Optional[torch.Tensor] = None) -> None:
        """
        Forward pass for the MLA module. Writes result into output tensor in-place.

        Inputs are first truncated to ``attn_metadata.num_tokens`` rows. The
        batch is laid out context-first:
        - rows ``[:num_ctx_tokens]`` are handled by ``forward_context``;
        - rows ``[num_ctx_tokens:num_tokens]`` are handled by
          ``forward_absorption_generation``.

        Per-token operations (unfused RoPE and Llama-4 attention scaling)
        receive the matching slice of ``position_ids``, so positions line up
        row-for-row with the sliced queries/keys.

        Args:
            position_ids (Optional[torch.IntTensor]): The position IDs.
            hidden_states (torch.Tensor): The hidden states.
            attn_metadata (AttentionMetadata): The attention metadata.
            output (torch.Tensor): The output tensor to write results into.
            latent_cache_gen (Optional[torch.Tensor]): The latent cache used in
                generation. When None, the generation rows of the
                ``[compressed_kv, k_pe]`` concatenation computed here are used.
        """
        # split q, k, v into context and gen batches
        num_contexts = attn_metadata.num_contexts
        num_generations = attn_metadata.num_generations
        num_ctx_tokens = attn_metadata.num_ctx_tokens
        num_tokens = attn_metadata.num_tokens

        hidden_states = hidden_states[:num_tokens, ...]
        if position_ids is not None:
            position_ids = position_ids[..., :num_tokens]

        if self.is_lite:
            # q is taken directly from hidden_states (no q_a layernorm);
            # q_b_proj is applied below.
            compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                [self.kv_lora_rank, self.qk_rope_head_dim], -1)
            compressed_kv = self.kv_a_layernorm(compressed_kv)
            q = hidden_states
        else:
            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
                hidden_states).split([
                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
                ], -1)

            q, compressed_kv = maybe_execute_in_parallel(
                lambda: self.q_a_layernorm(q),
                lambda: self.kv_a_layernorm(compressed_kv),
                self.ln_events[0],
                self.ln_events[1],
                self.aux_stream,
            )

        q, latent_cache = maybe_execute_in_parallel(
            lambda: self.q_b_proj(q),
            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )

        assert q.shape[
            0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"

        assert output is not None, "output must be provided"

        if num_contexts > 0:
            q_ctx = q[:num_ctx_tokens, ...]
            compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
            k_pe_ctx = k_pe[:num_ctx_tokens, ...]
            latent_cache_ctx = latent_cache[:num_ctx_tokens, ...]
            # Positions of the context rows only. Passing the full
            # position_ids would misalign with q_ctx/k_pe_ctx whenever the
            # batch also contains generation tokens.
            ctx_position_ids = (position_ids[..., :num_ctx_tokens]
                                if position_ids is not None else None)
            if self.apply_rotary_emb:
                assert ctx_position_ids is not None
                k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, ctx_position_ids)

            if self.llama_4_scaling:
                q_ctx = self._attention_scaling(q_ctx, ctx_position_ids)

            self.forward_context(
                q_ctx,
                compressed_kv_ctx,
                k_pe_ctx,
                position_ids,
                attn_metadata,
                output[:num_ctx_tokens, :],
                latent_cache_ctx,
            )

        if num_generations > 0:
            q_gen = q[num_ctx_tokens:, ...]
            compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
            k_pe_gen = k_pe[num_ctx_tokens:, ...]
            if latent_cache_gen is None:
                latent_cache_gen = latent_cache[num_ctx_tokens:, ...]
            # Positions of the generation rows only (see the context branch).
            gen_position_ids = (position_ids[..., num_ctx_tokens:]
                                if position_ids is not None else None)
            if self.apply_rotary_emb:
                assert gen_position_ids is not None
                k_pe_gen = self.apply_rope(q_gen, k_pe_gen, gen_position_ids)

            if self.llama_4_scaling:
                q_gen = self._attention_scaling(q_gen, gen_position_ids)

            self.forward_absorption_generation(
                q_gen,
                compressed_kv_gen,
                k_pe_gen,
                attn_metadata,
                output[num_ctx_tokens:num_tokens, :],
                position_ids=position_ids,
                latent_cache=latent_cache_gen,
            )

    def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
                              hidden_states: torch.Tensor,
                              attn_metadata: AttentionMetadata,
                              output: torch.Tensor) -> None:
        """Run DSA MLA (always MQA mode), writing the result into ``output``.

        This is the unsplit entry point. It runs ``forward_dsa_proj`` (the
        token-wise projections) and passes the first four results plus the
        remaining indexer intermediates to ``forward_dsa_attn`` (the
        batch-dependent attention dispatch).

        Args:
            position_ids (Optional[torch.IntTensor]): The position IDs.
            hidden_states (torch.Tensor): The hidden states.
            attn_metadata (AttentionMetadata): The attention metadata.
            output (torch.Tensor): The output tensor to write results into.
        """
        q, compressed_kv, k_pe, latent_cache, *indexer_intermediates = (
            self.forward_dsa_proj(position_ids, hidden_states, attn_metadata))
        self.forward_dsa_attn(q, compressed_kv, k_pe, latent_cache,
                              indexer_intermediates, position_ids,
                              attn_metadata, output)

    def forward_dsa_proj(
        self,
        position_ids: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> List[torch.Tensor]:
        """Token-wise projections for DSA MLA (Op 1, CUDA-graph capturable).

        Runs kv_a_proj_with_mqa, the q/kv latent layernorms and q_b_proj. It
        also runs the indexer's ``pre_indexer_proj`` unless the short-seq MHA
        path covers every context token and there are no generation tokens.

        This method must not slice tensors by ``num_tokens`` or otherwise
        depend on batch-specific metadata, so the captured control flow stays
        straight-line. The only metadata it consults is behind
        ``_should_use_short_mha``, which returns False under torch compile.
        Slicing to ``num_tokens`` happens in ``forward_dsa_attn`` (Op 2,
        outside the graph).

        Returns:
            ``[q, compressed_kv, k_pe, latent_cache]`` when short-MHA handles
            all tokens (eager only), otherwise
            ``[q, compressed_kv, k_pe, latent_cache, q_fp8, k_fp8, k_scale,
            weights]``. Under torch compile the result always has length 8.
        """
        assert self.mqa is not None, "DSA is only supported in MQA mode"

        fused = self.kv_a_proj_with_mqa(hidden_states)
        q_latent, compressed_kv, k_pe = fused.split(
            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], -1)

        # Normalize both latents, possibly overlapping on self.aux_stream.
        q_latent, compressed_kv = maybe_execute_in_parallel(
            lambda: self.q_a_layernorm(q_latent),
            lambda: self.kv_a_layernorm(compressed_kv),
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )
        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
        q = self.q_b_proj(q_latent)

        # The indexer is unnecessary when the short MHA path handles all
        # context tokens and there are no generation tokens.
        if (self._should_use_short_mha(attn_metadata, position_ids)
                and attn_metadata.num_generations == 0):
            return [q, compressed_kv, k_pe, latent_cache]

        # pre_indexer_proj is the CUDA-graph-safe, purely token-wise part of
        # the indexer (no batch metadata, no k-cache access). It consumes the
        # normalized q latent, i.e. the input to q_b_proj.
        q_fp8, k_fp8, k_scale, weights = self.mqa.indexer.pre_indexer_proj(
            q_latent, hidden_states, position_ids)
        return [
            q, compressed_kv, k_pe, latent_cache, q_fp8, k_fp8, k_scale, weights
        ]

    def forward_dsa_attn(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        latent_cache: torch.Tensor,
        indexer_intermediates: List[torch.Tensor],
        position_ids: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
    ) -> None:
        """Batch-structure-dependent attention for DSA MLA (Op 2, not graph-captured).

        indexer_intermediates is [q_fp8, k_fp8, k_scale, weights] when the
        indexer ran in Op 1, or [] when Op 1 skipped it (short-MHA handles
        all context tokens and there are no generation tokens). An empty list
        always means no sparse top-k indices are computed.

        All num_tokens slicing happens here (not in Op 1) because
        num_tokens comes from batch-specific metadata and must not be
        baked into CUDA graph capture.

        Rows are laid out context-first:
        - ``[:num_ctx_tokens]`` go to ``forward_context_dsa``;
        - ``[num_ctx_tokens:num_tokens]`` go to ``forward_generation_dsa``.

        The result is written into ``output`` in place.
        """
        num_contexts = attn_metadata.num_contexts
        num_generations = attn_metadata.num_generations
        num_ctx_tokens = attn_metadata.num_ctx_tokens
        num_tokens = attn_metadata.num_tokens

        # Slice Op 1 outputs to actual num_tokens (Op 1 operates on the
        # full padded tensor for CUDA graph compatibility).
        q = q[:num_tokens, ...]
        compressed_kv = compressed_kv[:num_tokens, ...]
        k_pe = k_pe[:num_tokens, ...]
        latent_cache = latent_cache[:num_tokens, ...]
        if position_ids is not None:
            position_ids = position_ids[..., :num_tokens]

        use_short_mha_for_ctx = (num_contexts > 0
                                 and self._should_use_short_mha(
                                     attn_metadata, position_ids))

        # Op 1 decides whether to skip the indexer without the
        # `num_contexts > 0` check used above. It can therefore return no
        # intermediates even when use_short_mha_for_ctx is False here (e.g. a
        # batch with no context and no generation tokens). Treat an empty
        # list as "indexer skipped" instead of failing to unpack it.
        if (not indexer_intermediates
                or (use_short_mha_for_ctx and num_generations == 0)):
            topk_indices = None
        else:
            q_fp8, k_fp8, k_scale, weights = indexer_intermediates
            # Slice indexer intermediates to actual num_tokens (they were
            # computed on the full padded tensor in Op 1).
            q_fp8 = q_fp8[:num_tokens, ...]
            k_fp8 = k_fp8[:num_tokens, ...]
            k_scale = k_scale[:num_tokens, ...]
            weights = weights[:num_tokens, ...]
            topk_indices = self.mqa.indexer.sparse_attn_indexer(
                attn_metadata,
                q,  # only used for shape/device in buffer allocation
                q_fp8,
                k_fp8,
                k_scale,
                weights,
            )

        assert output is not None, "output must be provided"

        if num_contexts > 0:
            q_ctx = q[:num_ctx_tokens, ...]
            compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
            k_pe_ctx = k_pe[:num_ctx_tokens, ...]
            latent_cache_ctx = latent_cache[:num_ctx_tokens, ...]
            if self.apply_rotary_emb:
                assert position_ids is not None
                # RoPE needs the context rows' positions only, so they line
                # up with q_ctx/k_pe_ctx in mixed batches.
                k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx,
                                           position_ids[..., :num_ctx_tokens])

            self.forward_context_dsa(
                q_ctx,
                compressed_kv_ctx,
                k_pe_ctx,
                attn_metadata,
                output[:num_ctx_tokens, :],
                latent_cache_ctx,
                topk_indices=topk_indices[:num_ctx_tokens, :]
                if topk_indices is not None else None,
                position_ids=position_ids,
            )

        if num_generations > 0:
            # Op 1 always runs the indexer when generation tokens are present.
            assert topk_indices is not None, (
                "indexer output is required for generation tokens")
            q_gen = q[num_ctx_tokens:, ...]
            compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
            k_pe_gen = k_pe[num_ctx_tokens:, ...]
            latent_cache_gen = latent_cache[num_ctx_tokens:, ...]
            if self.apply_rotary_emb:
                assert position_ids is not None
                # Generation rows' positions only (see the context branch).
                k_pe_gen = self.apply_rope(q_gen, k_pe_gen,
                                           position_ids[..., num_ctx_tokens:])

            self.forward_generation_dsa(
                q_gen,
                compressed_kv_gen,
                k_pe_gen,
                attn_metadata,
                output[num_ctx_tokens:num_tokens, :],
                latent_cache_gen,
                topk_indices=topk_indices[num_ctx_tokens:num_tokens, :],
            )

    def forward_context_default(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
        latent_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Dense MHA context path: expand KV via kv_b_proj and run attention.

        Used by non-DSA models and as the short-seq MHA fallback for DSA
        models.

        Steps:
        1. Split ``kv_b_proj(compressed_kv)`` into per-head k_nope and v.
        2. Assemble per-head K as ``[k_nope | k_pe]``, with ``k_pe`` shared
           across heads.
        3. Run ``self.mha`` in context-only mode, writing into ``output``.

        ``position_ids`` is accepted but unused here.

        Returns:
            The result of ``self.mha.forward``.
        """
        heads = self.num_heads_tp
        nope_dim = self.qk_nope_head_dim
        k_nope, v = self.kv_b_proj(compressed_kv).split(
            [heads * nope_dim, heads * self.v_head_dim], -1)

        # Per-head key buffer with q's shape/dtype: [tokens, heads, qk_head_dim].
        k_heads = torch.empty_like(q).view(-1, heads, self.qk_head_dim)
        maybe_compiled_copy_(k_heads[..., :nope_dim],
                             k_nope.view(-1, heads, nope_dim))
        if self.apply_rotary_emb:
            # Broadcast the single shared k_pe across all heads.
            k_heads[..., nope_dim:] = k_pe.view(-1, 1, self.qk_rope_head_dim)
        # Otherwise (rope_fusion=True) the rope channels stay uninitialized;
        # the fused attention kernel handles k_pe RoPE via latent_cache.

        return self.mha.forward(
            q,
            k_heads.view(-1, heads * self.qk_head_dim),
            v,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=latent_cache,
            out_scale=self.out_scale,
            output=output,
        )

    def _should_use_short_mha(self, attn_metadata: AttentionMetadata,
                              position_ids: Optional[torch.Tensor]) -> bool:
        """Check if the short-seq MHA optimization should be used for context.

        Requires all of:
        - a positive ``short_seq_mha_threshold``;
        - fused RoPE (``apply_rotary_emb`` False);
        - ``cp_size == 1``;
        - non-None ``position_ids``.

        The length compared against the threshold is ``max_ctx_kv_len`` (max
        total KV length per context sequence, including cached tokens) when
        it is available. That correctly accounts for chunked context, where
        the full attention span exceeds the threshold even if the new token
        count is small. When ``max_ctx_kv_len`` is missing or None, it falls
        back to ``num_ctx_tokens`` (total new context tokens).

        Disabled under torch compile so that the split DSA custom ops
        (mla_dsa_proj / mla_dsa_attn_inplace) have unconditionally
        straight-line control flow for CUDA graph capture.
        """
        if is_torch_compiling():
            return False
        if not (self.short_seq_mha_threshold > 0 and not self.apply_rotary_emb
                and self.mapping.cp_size == 1 and position_ids is not None):
            return False
        # A present-but-unset (None) max_ctx_kv_len must also fall back;
        # comparing None with the int threshold would raise TypeError.
        effective_len = getattr(attn_metadata, 'max_ctx_kv_len', None)
        if effective_len is None:
            effective_len = attn_metadata.num_ctx_tokens
        return effective_len <= self.short_seq_mha_threshold

    def forward_context_dsa(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
        latent_cache: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run context-phase attention for DSA models.

        Dispatch order:
        1. If ``_should_use_short_mha`` holds, use ``forward_context`` (the
           short-seq MHA path). It expands KV via kv_b_proj, runs dense
           attention and ignores ``topk_indices``. ``forward_context`` itself
           chooses between the default / cached-KV / chunked-prefill handlers.
        2. Otherwise, on SM100+ GPUs, use ``forward_absorption_context``.
        3. Otherwise use ``forward_sparse_mla_kvcache_bf16`` with
           ``is_generation=False``.

        Args:
            q: Query tensor, shape [num_ctx_tokens, num_heads * qk_head_dim].
            compressed_kv: Latent KV, shape [num_ctx_tokens, kv_lora_rank].
            k_pe: RoPE key portion, shape [num_ctx_tokens, qk_rope_head_dim].
            attn_metadata: Attention metadata for the current batch.
            output: Pre-allocated output tensor, written in-place.
            latent_cache: Concatenated [compressed_kv, k_pe] for KV cache.
            topk_indices: Sparse routing indices from the indexer (None when
                the short-seq MHA path is used).
            position_ids: Token position IDs (required for short-seq MHA).
        """
        # Short prefills: dense attention beats sparse routing at this scale.
        if self._should_use_short_mha(attn_metadata, position_ids):
            return self.forward_context(q, compressed_kv, k_pe, position_ids,
                                        attn_metadata, output, latent_cache)

        if get_sm_version() < 100:
            return self.forward_sparse_mla_kvcache_bf16(q,
                                                        latent_cache,
                                                        attn_metadata,
                                                        output,
                                                        topk_indices,
                                                        is_generation=False)
        return self.forward_absorption_context(q,
                                               compressed_kv,
                                               k_pe,
                                               attn_metadata,
                                               output,
                                               latent_cache=latent_cache,
                                               topk_indices=topk_indices)

    def forward_generation_dsa(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
        latent_cache: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run generation-phase attention for DSA models.

        SM100+ GPUs use ``forward_absorption_generation``. Older GPUs use
        ``forward_sparse_mla_kvcache_bf16`` with ``is_generation=True``; on
        that path ``compressed_kv`` and ``k_pe`` are not forwarded.
        ``output`` is passed through to the chosen implementation.
        """
        if get_sm_version() < 100:
            return self.forward_sparse_mla_kvcache_bf16(q,
                                                        latent_cache,
                                                        attn_metadata,
                                                        output,
                                                        topk_indices,
                                                        is_generation=True)
        return self.forward_absorption_generation(q,
                                                  compressed_kv,
                                                  k_pe,
                                                  attn_metadata,
                                                  output,
                                                  latent_cache=latent_cache,
                                                  topk_indices=topk_indices)

    def forward_context_with_cached_kv(
        self,
        q: torch.Tensor,
        latent_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Context attention over previously cached plus new context tokens.

        Steps:
        1. Apply RoPE, append ``latent_cache`` to the paged KV cache and
           assign q_pe into ``q``.
        2. Load the full (cached + new) compressed_kv / k_pe back.
        3. Expand it with kv_b_proj into per-head K and V.
        4. Run ``self.mha`` in context-only mode with ``latent_cache=None``,
           so the attention op skips RoPE and KV-cache appending.

        Returns:
            The result of ``self.mha.forward`` (also written into ``output``).
        """
        assert latent_cache is not None
        trtllm_attention = cast(TrtllmAttention, self.mha)

        # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
        trtllm_attention.mla_rope_append_paged_kv_assign_q(
            q, latent_cache, attn_metadata)

        # copy full_compressed_kv and full_k_pe from paged kv cache
        full_compressed_kv, full_k_pe = trtllm_attention.load_paged_kv_cache_for_mla(
            attn_metadata, q.dtype)
        total_rows = (attn_metadata.num_ctx_cached_tokens +
                      attn_metadata.num_ctx_tokens)
        assert full_compressed_kv.shape[0] == total_rows
        assert full_compressed_kv.shape[1] == self.kv_lora_rank
        assert full_k_pe.shape[0] == total_rows
        assert full_k_pe.shape[1] == self.qk_rope_head_dim
        assert full_compressed_kv.is_contiguous()
        assert full_k_pe.is_contiguous()

        heads = self.num_heads_tp
        full_k_nope, full_v = self.kv_b_proj(full_compressed_kv).split(
            [heads * self.qk_nope_head_dim, heads * self.v_head_dim], -1)

        # K = [k_nope | k_pe], with the single k_pe broadcast to every head.
        full_k = maybe_compiled_cat(
            (full_k_nope.view(-1, heads, self.qk_nope_head_dim),
             full_k_pe.view(-1, 1, self.qk_rope_head_dim).expand(-1, heads,
                                                                 -1)),
            dim=-1,
        ).view(-1, heads * self.qk_head_dim)

        # release pytorch activation memory
        del full_compressed_kv, full_k_pe, full_k_nope

        # latent_cache must be None to differentiate from normal context phase,
        # so that we can skip applying RoPE and appending KV cache inside attention op
        return self.mha.forward(
            q,
            full_k,
            full_v,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=None,
            out_scale=self.out_scale,
            output=output,
        )

    def forward_context_with_chunked_prefill(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        latent_cache: torch.
        Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
        attn_metadata: TrtllmAttentionMetadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        trtllm_attention = cast(TrtllmAttention, self.mha)
        # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
        trtllm_attention.mla_rope_append_paged_kv_assign_q(
            q, latent_cache, attn_metadata)

        # determine the number of loop
        # currently we assume that the chunk size is the same as the max_num_tokens
        chunked_loop_num = attn_metadata.chunked_loop_num

        # [total_token_q, num_heads, 2] -> [total_token_q, num_heads] float2
        self.softmax_stats_tensor = torch.empty(
            (attn_metadata.num_ctx_tokens, self.num_heads_tp, 2),
            dtype=torch.float,
            device='cuda',
        )
        self.temp_softmax_stats_tensor = torch.empty(
            (attn_metadata.num_ctx_tokens, self.num_heads_tp, 2),
            dtype=torch.float,
            device='cuda',
        )

        attn_output = output
        temp_attn_output = q.new_empty(
            (q.size(0), self.num_heads_tp * self.v_head_dim), dtype=q.dtype)

        # use fake cached_cu_seq_len for chunked loop
        origin_kv_lens_cuda_runtime = attn_metadata.kv_lens_cuda_runtime
        origin_kv_lens_runtime = attn_metadata.kv_lens_runtime
        origin_ctx_total_kv_len = attn_metadata.host_total_kv_lens[0]

        for loop_idx in range(chunked_loop_num):
            # {b, chunked_unit_size, h, kv_lora_rank + qk_rope_head_dim} zero padded
            # fetch `loop_idx` chunk from kv cache
            temp_cu_chunked_seq_len = attn_metadata.cu_chunked_seq_len[loop_idx]
            total_ctx_chunked_tokens = attn_metadata.host_cu_chunked_seq_len[
                loop_idx, attn_metadata.num_contexts]
            chunked_global_offset = attn_metadata.chunked_global_offset[
                loop_idx]
            chunked_max_seq_len = attn_metadata.max_chunk_len_per_loop[loop_idx]
            chunked_compressed_kv, chunked_k_pe = trtllm_attention.load_chunked_kv_cache_for_mla(
                metadata=attn_metadata,
                num_ctx_cached_tokens=total_ctx_chunked_tokens,
                cu_chunked_seq_len=temp_cu_chunked_seq_len,
                chunked_global_offset=chunked_global_offset,
                chunked_max_seq_len=chunked_max_seq_len,
                out_dtype=q.dtype)

            # up proj to uncompressed kv
            # [tokens, 2, h, kv_dim], without rope_dim
            chunked_kv = self.kv_b_proj(chunked_compressed_kv)
            chunked_k_nope, chunked_v = chunked_kv.split(
                [
                    self.num_heads_tp * self.qk_nope_head_dim,
                    self.num_heads_tp * self.v_head_dim
                ],
                -1,
            )

            chunked_k_nope = chunked_k_nope.view(-1, self.num_heads_tp,
                                                 self.qk_nope_head_dim)
            chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
            chunked_k = maybe_compiled_cat(
                (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads_tp,
                                                     -1)),
                dim=-1)
            chunked_k = chunked_k.view(-1, self.num_heads_tp * self.qk_head_dim)

            # release pytorch activation memory
            chunked_compressed_kv = None
            chunked_k_pe = None
            chunked_kv = None
            chunked_k_nope = None

            # copy chunked_seq_len to replace kv_lens_runtime
            attn_metadata.kv_lens_runtime = attn_metadata.host_chunked_seq_len[
                loop_idx]
            attn_metadata.kv_lens_cuda_runtime = attn_metadata.chunked_seq_len[
                loop_idx]
            attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

            # do not apply mask for attention within loop
            # latent_cache must be None to differentiate from normal context phase,
            # so that we can skip applying RoPE and appending KV cache inside attention op
            temp_attn_output = self.mha.forward(
                q,
                chunked_k,
                chunked_v,
                attn_metadata,
                attention_input_type=AttentionInputType.context_only,
                latent_cache=None,
                out_scale=self.out_scale,
                attention_mask=PredefinedAttentionMask.FULL,
                softmax_stats_tensor=self.temp_softmax_stats_tensor,
                chunked_prefill_buffer_batch_size=attn_metadata.
                runtime_features.chunked_prefill_buffer_batch_size,
                output=temp_attn_output,
            )
            # merge attn result
            temp_merge_op = attn_metadata.merge_op_tensor[loop_idx]
            trtllm_attention.merge_attention_for_mla(
                attn_output, temp_attn_output, self.softmax_stats_tensor,
                self.temp_softmax_stats_tensor, temp_merge_op, attn_metadata)

        # deal with the uncached kv
        kv = self.kv_b_proj(compressed_kv)
        _, k_pe = latent_cache.view([
            -1, self.kv_lora_rank + self.qk_rope_head_dim
        ]).split([self.kv_lora_rank, self.qk_rope_head_dim], -1)
        # final round of attention

        k_nope, v = kv.split(
            [
                self.num_heads_tp * self.qk_nope_head_dim,
                self.num_heads_tp * self.v_head_dim
            ],
            -1,
        )

        k_nope = k_nope.view(-1, self.num_heads_tp, self.qk_nope_head_dim)
        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
        k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads_tp, -1)),
                               dim=-1)
        k = k.view(-1, self.num_heads_tp * self.qk_head_dim)

        # copy q_lens to replace kv_lens_runtime
        attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime
        attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime
        attn_metadata.host_total_kv_lens[
            0] = attn_metadata.prompt_lens_cpu_runtime[:attn_metadata.
                                                       num_contexts].sum().item(
                                                       )

        # latent_cache must be None to differentiate from normal context phase,
        # so that we can skip applying RoPE and appending KV cache inside attention op
        temp_attn_output = self.mha.forward(
            q,
            k,
            v,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=None,
            out_scale=self.out_scale,
            softmax_stats_tensor=self.temp_softmax_stats_tensor,
            chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
            chunked_prefill_buffer_batch_size,
            output=temp_attn_output,
        )
        temp_merge_op = attn_metadata.merge_op_tensor[chunked_loop_num]
        trtllm_attention.merge_attention_for_mla(attn_output, temp_attn_output,
                                                 self.softmax_stats_tensor,
                                                 self.temp_softmax_stats_tensor,
                                                 temp_merge_op, attn_metadata)
        # copy back kv_lens_runtime and kv_lens_cuda_runtime
        attn_metadata.kv_lens_runtime = origin_kv_lens_runtime
        attn_metadata.kv_lens_cuda_runtime = origin_kv_lens_cuda_runtime
        attn_metadata.host_total_kv_lens[0] = origin_ctx_total_kv_len

        return attn_output

    @staticmethod
    @functools.cache
    def cached_warmup_forward_context_with_chunked_prefill(
            num_heads_tp, qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank,
            v_head_dim, dtype, device):
        """Pre-compile the ``maybe_compiled_cat`` calls used by chunked-prefill MLA.

        At runtime the K tensor is built by concatenating a strided ``k_nope``
        view (a slice of the ``kv_b_proj`` output) with a ``k_pe`` view that is
        broadcast over heads via ``expand``. The compiled kernel depends on
        those layouts, so this helper reproduces both layouts on dummy tensors
        and runs the concatenations. ``functools.cache`` turns repeat calls
        with the same arguments into no-ops.

        Dim 0 (num_tokens) is marked with ``torch._dynamo.maybe_mark_dynamic``.
        One warmup with num_tokens != 1 therefore covers every such size at
        runtime. torch.compile still specializes num_tokens == 1, so that case
        is warmed up separately. ``dynamic=True`` is deliberately not used: it
        ignores tensor layout/stride information and significantly degrades
        performance.
        """

        def _empty(*shape):
            return torch.empty(*shape, dtype=dtype, device=device)

        def warmup(num_tokens):
            # k_nope: leading slice of a kv_b_proj-shaped buffer, giving the same
            # non-contiguous view that ``kv.split`` produces at runtime.
            kv = _empty(num_tokens,
                        num_heads_tp * (qk_nope_head_dim + v_head_dim))
            k_nope = kv[:, :num_heads_tp * qk_nope_head_dim].view(
                num_tokens, num_heads_tp, qk_nope_head_dim)
            # Chunked k_pe: a standalone [num_tokens, 1, rope] tensor
            # broadcast over heads.
            chunked_k_pe = _empty(num_tokens, 1,
                                  qk_rope_head_dim).expand(-1, num_heads_tp, -1)
            # Uncached k_pe: the rope tail of a latent-cache row, broadcast
            # over heads.
            latent = _empty(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)
            k_pe = latent[:, :, -qk_rope_head_dim:].expand(-1, num_heads_tp, -1)

            for tensor in (k_nope, chunked_k_pe, k_pe):
                torch._dynamo.maybe_mark_dynamic(tensor, 0)

            # The same k_nope layout serves both the chunked and uncached paths.
            maybe_compiled_cat((k_nope, chunked_k_pe), dim=-1)
            maybe_compiled_cat((k_nope, k_pe), dim=-1)

        # num_tokens=2 covers every num_tokens != 1 thanks to the dynamic dim;
        # num_tokens=1 is specialized by torch.compile and needs its own run.
        for num_tokens in (2, 1):
            warmup(num_tokens)

    def forward_context(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
        latent_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Dispatch context-phase MLA to the matching implementation.

        With a ``TrtllmAttention`` backend:
          * chunked prefill on SM >= 100 goes to
            ``forward_context_with_chunked_prefill``;
          * cached KV, or chunked prefill on older SMs, goes to
            ``forward_context_with_cached_kv``.
        Every other case, including non-TRTLLM backends, falls back to
        ``forward_context_default``.

        If the metadata marks a chunked-prefill warmup context, the compiled
        concat kernels are warmed up first. That warmup is cached per argument
        tuple (head sizes, dtype, device).
        """
        if not isinstance(self.mha, TrtllmAttention):
            return self.forward_context_default(q, compressed_kv, k_pe,
                                                position_ids, attn_metadata,
                                                output, latent_cache)

        assert isinstance(attn_metadata, TrtllmAttentionMetadata)
        backend = cast(TrtllmAttention, self.mha)

        if backend.is_chunked_prefill_mla_context_for_warmup(attn_metadata):
            self.cached_warmup_forward_context_with_chunked_prefill(
                self.num_heads_tp, self.qk_nope_head_dim,
                self.qk_rope_head_dim, self.kv_lora_rank, self.v_head_dim,
                q.dtype, q.device)

        chunked_prefill = backend.is_chunked_prefill_for_mla_context(
            attn_metadata)
        if chunked_prefill and get_sm_version() >= 100:
            return self.forward_context_with_chunked_prefill(
                q, compressed_kv, latent_cache, attn_metadata, output)
        if backend.has_cached_kv_for_mla_context(
                attn_metadata) or chunked_prefill:
            return self.forward_context_with_cached_kv(q, latent_cache,
                                                       attn_metadata, output)
        return self.forward_context_default(q, compressed_kv, k_pe,
                                            position_ids, attn_metadata,
                                            output, latent_cache)

    def forward_absorption_generation(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        latent_cache: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Generation-phase MLA with k_b/v_b projections absorbed into q/output.

        Steps:
          1. Project q_nope through ``k_b_proj_trans`` into the latent half of
             ``fused_q``. In parallel, ``mqa.mla_rope_generation`` fills the
             rope half of ``fused_q`` and the scratch buffers. The rope call
             runs on ``aux_stream`` unless the KV cache is FP8.
          2. Run latent attention via ``_attn_forward_gen`` (generation_only).
          3. Project the latent result through ``v_b_proj`` into ``output``.

        Returns:
            ``output``, written in place (viewed as
            [num_tokens, num_heads_tp_cp, v_head_dim]).

        Raises:
            NotImplementedError: unsupported ``k_b_proj_trans`` / ``v_b_proj``
                dtype (only bfloat16 and float8_e4m3fn are handled).
        """
        num_tokens = q.shape[0]
        device = q.device
        latent_rope_dim = self.kv_lora_rank + self.qk_rope_head_dim

        q_nope, q_pe = q.view([-1, self.num_heads_tp, self.qk_head_dim]).split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Scratch buffers filled by mla_rope_generation and consumed by
        # `mlaGeneration` inside the attention op.
        num_seqs = attn_metadata.kv_lens_cuda_runtime.size(0)
        cu_q_seqlens = torch.empty(num_seqs + 1,
                                   dtype=torch.int32,
                                   device=device)
        cu_kv_seqlens = torch.empty(num_seqs + 1,
                                    dtype=torch.int32,
                                    device=device)
        fmha_scheduler_counter = torch.empty(1,
                                             dtype=torch.uint32,
                                             device=device)

        has_fp8_kv_cache = getattr(self.mqa, 'has_fp8_kv_cache', False)
        if has_fp8_kv_cache:
            mla_bmm1_scale = torch.empty(2, dtype=torch.float32, device=device)
            mla_bmm2_scale = torch.empty(1, dtype=torch.float32, device=device)
            quant_q_buffer = torch.empty(num_tokens,
                                         self.num_heads_tp,
                                         latent_rope_dim,
                                         dtype=torch.uint8,
                                         device=device)
        else:
            mla_bmm1_scale = mla_bmm2_scale = quant_q_buffer = None

        # fused_q = [bmm(q_nope, k_b_proj) | rope(q_pe)] per head:
        # [num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim].
        fused_q = torch.empty(
            [num_tokens, self.num_heads_tp, latent_rope_dim],
            dtype=q.dtype,
            device=device,
        )

        rope_stream = None if has_fp8_kv_cache else self.aux_stream

        # [num_heads, num_tokens, kv_lora_rank] view of fused_q's latent half;
        # the k_b projection writes straight into it.
        q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
        k_b_dtype = self.k_b_proj_trans.dtype
        if k_b_dtype == torch.bfloat16:
            # [num_heads, num_tokens, qk_nope_head_dim]
            q_nope_t = q_nope.transpose(0, 1)

            def project_q_nope():
                # [h, t, nope] x [h, nope, lora] -> [h, t, lora]
                return torch.ops.trtllm.bmm_out(
                    q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out)
        elif k_b_dtype == torch.float8_e4m3fn:

            def project_q_nope():
                return fp8_block_scaling_bmm_out(
                    q_nope,
                    self.k_b_proj_trans,
                    self.k_b_proj_trans_scale,
                    q_nope_out,
                    self.k_b_proj_trans_dequant,
                    self.use_cute_dsl_blockscaling_bmm,
                )
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")

        def rope_generation():
            return self.mqa.mla_rope_generation(
                fused_q,
                q_pe,
                latent_cache,
                attn_metadata,
                cu_q_seqlens,
                cu_kv_seqlens,
                fmha_scheduler_counter,
                mla_bmm1_scale,
                mla_bmm2_scale,
                quant_q_buffer,
            )

        maybe_execute_in_parallel(
            project_q_nope,
            rope_generation,
            self.ln_events[0],
            self.ln_events[1],
            rope_stream,
        )

        fused_q = fused_q.view([num_tokens, self.num_heads_tp * latent_rope_dim])

        # DSA attention distinguishes phases: generation_only here,
        # context_only in the context path.
        attn_out_latent = self._attn_forward_gen(
            self.mqa,
            fused_q,
            None,
            None,
            position_ids,
            attn_metadata,
            attention_input_type=AttentionInputType.generation_only,
            out_scale=self.out_scale,
            latent_cache=latent_cache,  # kvcache and k_pe
            q_pe=q_pe,  # used by `invokeMLARopeGeneration`
            topk_indices=topk_indices,  # used by DSA attention
            is_generation=True,  # used by DSA attention
            cu_q_seqlens=cu_q_seqlens,  # used by `mlaGeneration`
            cu_kv_seqlens=cu_kv_seqlens,  # used by `mlaGeneration`
            fmha_scheduler_counter=
            fmha_scheduler_counter,  # used by `mlaGeneration`
            mla_bmm1_scale=mla_bmm1_scale,  # used by `mlaGeneration`
            mla_bmm2_scale=mla_bmm2_scale,  # used by `mlaGeneration`
            quant_q_buffer=quant_q_buffer,  # used by `mlaGeneration`
        )
        fused_q = None

        # Without CP, num_heads_tp_cp == num_heads_tp.
        assert (attn_out_latent.shape[0] == q.shape[0]
                and attn_out_latent.shape[1]
                == self.num_heads_tp_cp * self.kv_lora_rank)

        # [seq, num_heads, kv_lora_rank]
        attn_out_latent = attn_out_latent.view(
            [-1, self.num_heads_tp_cp, self.kv_lora_rank])
        attn_output = output.view(
            [num_tokens, self.num_heads_tp_cp, self.v_head_dim])

        v_b_dtype = self.v_b_proj.dtype
        if v_b_dtype == torch.bfloat16:
            # [h, seq, lora] x [h, lora, v] -> [h, seq, v]
            torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
                                     self.v_b_proj.transpose(1, 2),
                                     attn_output.transpose(0, 1))
        elif v_b_dtype == torch.float8_e4m3fn:
            fp8_block_scaling_bmm_out(
                attn_out_latent,
                self.v_b_proj,
                self.v_b_proj_scale,
                attn_output.transpose(0, 1),
                self.v_b_proj_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")

        return output

    def forward_absorption_context(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        latent_cache: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Context-phase MLA with k_b/v_b projections absorbed into q/output.

        Projects q_nope through ``k_b_proj_trans`` into the latent half of
        ``fused_q``. When ``apply_rotary_emb`` is set, q_pe is copied into the
        rope half. Latent attention then runs through ``_attn_forward_gen``
        (context_only), and the result is projected through ``v_b_proj`` into
        ``output``.

        Returns:
            ``output``, written in place (viewed as
            [num_tokens, num_heads_tp_cp, v_head_dim]).

        Raises:
            NotImplementedError: unsupported ``k_b_proj_trans`` / ``v_b_proj``
                dtype (only bfloat16 and float8_e4m3fn are handled).
        """
        num_tokens = q.shape[0]
        latent_rope_dim = self.kv_lora_rank + self.qk_rope_head_dim
        q_nope, q_pe = q.view([-1, self.num_heads_tp, self.qk_head_dim]).split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # fused_q = [bmm(q_nope, k_b_proj) | q_pe] per head:
        # [num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim].
        # When RoPE is not applied here, it is applied inside the attention op.
        fused_q = torch.empty(
            [num_tokens, self.num_heads_tp, latent_rope_dim],
            dtype=q.dtype,
            device=q.device,
        )
        # [num_heads, num_tokens, kv_lora_rank] view of the latent half;
        # the k_b projection writes straight into it.
        q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)

        k_b_dtype = self.k_b_proj_trans.dtype
        if k_b_dtype == torch.bfloat16:
            # [h, t, nope] x [h, nope, lora] -> [h, t, lora]
            torch.ops.trtllm.bmm_out(q_nope.transpose(0, 1),
                                     self.k_b_proj_trans.transpose(1, 2),
                                     q_nope_out)
        elif k_b_dtype == torch.float8_e4m3fn:
            fp8_block_scaling_bmm_out(
                q_nope,
                self.k_b_proj_trans,
                self.k_b_proj_trans_scale,
                q_nope_out,
                self.k_b_proj_trans_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")

        if self.apply_rotary_emb:
            fused_q[..., self.kv_lora_rank:] = q_pe
        fused_q = fused_q.view([num_tokens, self.num_heads_tp * latent_rope_dim])

        # DSA attention distinguishes phases: context_only here,
        # generation_only in the generation path.
        attn_out_latent = self._attn_forward_gen(
            self.mqa,
            fused_q,
            None,
            None,
            position_ids,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            out_scale=self.out_scale,
            latent_cache=latent_cache,  # kvcache and k_pe
            q_pe=q_pe,  # used by `invokeMLARopeGeneration`
            topk_indices=topk_indices,  # used by DSA attention
            is_generation=False,  # used by DSA attention
        )
        fused_q = None

        # Without CP, num_heads_tp_cp == num_heads_tp.
        assert (attn_out_latent.shape[0] == q.shape[0]
                and attn_out_latent.shape[1]
                == self.num_heads_tp_cp * self.kv_lora_rank)

        # [seq, num_heads, kv_lora_rank]
        attn_out_latent = attn_out_latent.view(
            [-1, self.num_heads_tp_cp, self.kv_lora_rank])
        attn_output = output.view(
            [num_tokens, self.num_heads_tp_cp, self.v_head_dim])

        v_b_dtype = self.v_b_proj.dtype
        if v_b_dtype == torch.bfloat16:
            # [h, seq, lora] x [h, lora, v] -> [h, seq, v]
            torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
                                     self.v_b_proj.transpose(1, 2),
                                     attn_output.transpose(0, 1))
        elif v_b_dtype == torch.float8_e4m3fn:
            fp8_block_scaling_bmm_out(
                attn_out_latent,
                self.v_b_proj,
                self.v_b_proj_scale,
                attn_output.transpose(0, 1),
                self.v_b_proj_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")

        return output

    @nvtx_range("forward_sparse_mla_kvcache_bf16")
    def forward_sparse_mla_kvcache_bf16(
        self,
        q: torch.Tensor,
        latent_cache: torch.Tensor,
        attn_metadata: DSAtrtllmAttentionMetadata,
        output: torch.Tensor,
        topk_indices: torch.Tensor,
        is_generation: bool = False,
    ) -> torch.Tensor:
        """
        Forward sparse MLA (DSA) for BF16 KV cache for both context and generation phases using FlashMLA kernels

        To form the input for FlashMLA kernel and adapt our KV cache manager, we need to:
        1. Append current tokens to paged cache and apply rope to q/k via mla_rope_append_paged_kv_assign_q
        2. Load full kv cache from paged memory (with k rope applied)
        3. Call FlashMLA sparse attention kernel for sparse prefill/decode

        Args:
            q: Query tensor, viewed here as
                [num_tokens, num_heads_tp, qk_nope_head_dim + qk_rope_head_dim].
                RoPE is applied to it in place by step 1.
            latent_cache: Latent KV for the current tokens; step 1 appends it
                to the paged KV cache.
            attn_metadata: DSA attention metadata.
            output: Destination buffer, viewed as
                [num_tokens, num_heads_tp, v_head_dim] and written in place.
            topk_indices: Local sparse top-k indices. They are converted by
                ``transform_local_topk_and_prepare_pool_view`` into indices
                into the all-layer KV pool.
            is_generation: True for the generation phase, False for context.

        Returns:
            ``output``.

        Raises:
            RuntimeError: The FlashMLA sparse kernel is not available. This is
                checked before the KV cache or ``q`` is modified.
            NotImplementedError: ``k_b_proj_trans`` or ``v_b_proj`` has an
                unsupported dtype.
        """
        assert isinstance(attn_metadata, DSAtrtllmAttentionMetadata), \
            "DSA requires DSAtrtllmAttentionMetadata"
        # Fail fast: the append step below writes latent_cache into the paged
        # KV cache and rewrites q in place. Raising after it would leave those
        # side effects behind for a forward pass that never produces output.
        if flash_mla_sparse_fwd is None:
            raise RuntimeError(
                "flash_mla_sparse_fwd not available. Please ensure FlashMLA module is built."
            )

        # Append current tokens to paged cache and apply RoPE to q
        # This writes latent_cache to paged KV and modifies q in-place
        trtllm_attention = self.mqa
        with nvtx_range_debug(
                f"mla_rope_append_paged_kv_assign_q_is_generation={is_generation}"
        ):
            trtllm_attention.mla_rope_append_paged_kv_assign_q(
                q, latent_cache, attn_metadata, is_generation=is_generation)

        num_tokens = q.shape[0]
        q_nope, q_rope = q.view(-1, self.num_heads_tp, self.qk_head_dim).split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_nope_out = torch.empty(
            [num_tokens, self.num_heads_tp, (self.kv_lora_rank)],
            dtype=q.dtype,
            device=q.device,
        )

        if self.k_b_proj_trans.dtype == torch.bfloat16:
            # [num_heads, num_tokens, self.qk_nope_head_dim]
            q_nope_t = q_nope.transpose(0, 1)
            # [num_heads, num_tokens, self.kv_lora_rank]
            q_nope_out = q_nope_out.transpose(0, 1)

            # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
            # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
            # The output of bmm is written directly into q_nope_out
            torch.ops.trtllm.bmm_out(q_nope_t,
                                     self.k_b_proj_trans.transpose(1, 2),
                                     q_nope_out)
        elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
            # [num_heads, num_tokens, self.kv_lora_rank]
            q_nope_out = q_nope_out.transpose(0, 1)

            fp8_block_scaling_bmm_out(
                q_nope,
                self.k_b_proj_trans,
                self.k_b_proj_trans_scale,
                q_nope_out,
                self.k_b_proj_trans_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")

        # Back to [num_tokens, num_heads, kv_lora_rank], then append the
        # roped q part: [num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim]
        q_nope_out = q_nope_out.transpose(0, 1)
        q_concat = torch.cat([q_nope_out, q_rope], dim=-1)

        sm_version = get_sm_version()
        # FlashMLA sparse kernel (bf16) requires num_heads=128 on sm100 or multiple of 64 on sm90
        if sm_version >= 100:
            padding = 128
            assert self.num_heads_tp <= padding, (
                f"SM100 FlashMLA sparse kernel requires exactly {padding} heads, "
                f"got {self.num_heads_tp}. Padding from values > {padding} is not supported."
            )
        else:  # SM90
            padding = ((self.num_heads_tp + 63) // 64) * 64  # multiple of 64

        if self.num_heads_tp != padding:
            logger.warning_once(
                f"Padding num_heads from {self.num_heads_tp} to {padding} "
                f"due to FlashMLA sparse attention kernel requirement",
                key="sparse_mla_padding_warning")

            # Create padded tensor with zeros for extra heads, so the kernel
            # never reads uninitialized memory. The outputs of the extra heads
            # are sliced off after the kernel call.
            q_padded = q_concat.new_zeros(
                (num_tokens, padding, q_concat.shape[2]))
            q_padded[:, :self.num_heads_tp, :] = q_concat
            q_concat = q_padded

        # Convert indices and return all-layer KV pool
        # Note: underlying pool is layer-interleaved: [num_blocks, num_layers, kv_factor, tokens_per_block, num_kv_heads, head_dim]
        # to avoid reshape(copy) per-layer KV cache, we return all-layer KV pool w/ topk indices adjusted by stride_factor=num_layers*tokens_per_block
        topk_indices_pool, kv_cache_pool = transform_local_topk_and_prepare_pool_view(
            topk_indices,
            attn_metadata,
            layer_idx=self.layer_idx,
            is_generation=is_generation,
        )
        topk_indices_pool = topk_indices_pool.view(num_tokens, 1, -1)
        # Availability of flash_mla_sparse_fwd was verified at entry.
        attn_out_latent = flash_mla_sparse_fwd(q_concat, kv_cache_pool,
                                               topk_indices_pool,
                                               self.softmax_scale)[0]

        # [seq, num_heads, kv_lora_rank], account for padding
        attn_out_latent = attn_out_latent[:, :self.num_heads_tp, :]
        attn_out_latent = attn_out_latent.view(
            [-1, self.num_heads_tp, self.kv_lora_rank])
        if self.num_heads_tp != padding:
            # The head slice above is strided; make it dense for the bmm.
            attn_out_latent = attn_out_latent.contiguous()

        assert (attn_out_latent.shape[0] == q.shape[0]
                and attn_out_latent.shape[1] == self.num_heads_tp)

        attn_output = output.view(
            [num_tokens, self.num_heads_tp, self.v_head_dim])

        if self.v_b_proj.dtype == torch.bfloat16:
            # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
            # -> [num_heads, seq, v_head_dim]
            torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1),
                                     self.v_b_proj.transpose(1, 2),
                                     attn_output.transpose(0, 1))
        elif self.v_b_proj.dtype == torch.float8_e4m3fn:
            fp8_block_scaling_bmm_out(
                attn_out_latent,
                self.v_b_proj,
                self.v_b_proj_scale,
                attn_output.transpose(0, 1),
                self.v_b_proj_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
                f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
        return output

    def forward(
        self,
        position_ids: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        all_reduce_params: Optional[AllReduceParams] = None,
        latent_cache_gen: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention and the (Helix-CP aware) output projection.

        The input is first all-gathered through ``_helix_cp_allgather_input``.
        Attention then writes into a buffer from ``create_output``, using one
        of two routes:
          * ``register_to_config``: through the registered ``torch.ops.trtllm``
            custom ops (``mla_dsa_proj`` + ``mla_dsa_attn_inplace`` for DSA,
            ``mla_custom_op_inplace`` otherwise);
          * otherwise: directly through ``forward_impl_with_dsa`` /
            ``forward_impl``.
        The result is passed to ``_helix_cp_output_projection`` with ``o_proj``.
        """
        gathered = _helix_cp_allgather_input(hidden_states, attn_metadata,
                                             self.mapping, self.layer_idx)
        attn_output = self.create_output(gathered, attn_metadata.num_contexts)

        if not self.register_to_config:
            if self.is_dsa:
                self.forward_impl_with_dsa(position_ids,
                                           gathered,
                                           attn_metadata,
                                           output=attn_output)
            else:
                self.forward_impl(position_ids,
                                  gathered,
                                  attn_metadata,
                                  output=attn_output,
                                  latent_cache_gen=latent_cache_gen)
        elif self.is_dsa:
            proj_outputs = torch.ops.trtllm.mla_dsa_proj(
                gathered, position_ids, self.layer_idx_str)
            # First four outputs are attention inputs; the rest are indexer
            # intermediates forwarded verbatim.
            q, compressed_kv, k_pe, latent_cache = proj_outputs[:4]
            torch.ops.trtllm.mla_dsa_attn_inplace(q, compressed_kv, k_pe,
                                                  latent_cache,
                                                  proj_outputs[4:],
                                                  position_ids,
                                                  self.layer_idx_str,
                                                  attn_output)
        else:
            torch.ops.trtllm.mla_custom_op_inplace(gathered, position_ids,
                                                   self.layer_idx_str,
                                                   attn_output,
                                                   latent_cache_gen)

        return _helix_cp_output_projection(self.o_proj, attn_output,
                                           attn_metadata, all_reduce_params,
                                           self.mapping, self.mapping_o,
                                           self.layer_idx)

    def resmooth_parameters(self,
                            module_weight,
                            module_weight_scale,
                            recipe=(1, 128, 128)):
        """Re-quantize a weight/scale pair and wrap the results as frozen Parameters.

        The pair is passed through ``fp8_utils.resmooth_to_fp8_e8m0``. The
        resulting scale is then relaid out with
        ``fp8_utils.transform_sf_into_required_layout``, using the weight's
        dims 1 and 2 as ``mn``/``k`` and dim 0 as ``num_groups``. The weight
        is therefore expected to be at least 3-D (presumably
        [groups/heads, mn, k]; confirm against callers).

        Args:
            module_weight: Weight tensor to re-quantize.
            module_weight_scale: Its scale-factor tensor.
            recipe: Forwarded unchanged to
                ``transform_sf_into_required_layout``. Presumably the scaling
                block shape; verify in fp8_utils.

        Returns:
            ``(weight_param, scale_param)``, both ``torch.nn.Parameter`` with
            ``requires_grad=False``.
        """
        new_weight, new_scale = fp8_utils.resmooth_to_fp8_e8m0(
            module_weight, module_weight_scale)
        weight_shape = new_weight.shape

        laid_out_scale = fp8_utils.transform_sf_into_required_layout(
            new_scale,
            mn=weight_shape[1],
            k=weight_shape[2],
            recipe=recipe,
            num_groups=weight_shape[0],
            is_sfa=False)

        return (torch.nn.Parameter(new_weight, requires_grad=False),
                torch.nn.Parameter(laid_out_scale, requires_grad=False))

    def post_load_weights(self):
        has_fp8_block_scales = (
            self.kv_b_proj.quant_config
            and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales())
        is_sm120 = get_sm_version() == 120
        if is_sm120 and has_fp8_block_scales:
            self.k_b_proj_trans, self.k_b_proj_trans_scale = self.resmooth_parameters(
                self.k_b_proj_trans,
                self.k_b_proj_trans_scale,
                recipe=(1, 128, 128))

            self.v_b_proj, self.v_b_proj_scale = self.resmooth_parameters(
                self.v_b_proj, self.v_b_proj_scale, recipe=(1, 128, 128))
