FMHA Params V2#

Enums

enum class ContextAttentionMaskType#

Enumeration of context attention mask types.

Values:

enumerator PADDING = 0#

Mask the padded tokens.

enumerator CAUSAL#

Mask the padded tokens and all tokens that come after the current token in the sequence.

enumerator SLIDING_OR_CHUNKED_CAUSAL#

Causal mask that additionally restricts attention to the specified sliding window or chunk.

enumerator CUSTOM_MASK#

The custom mask input.

enum class AttentionInputLayout#

Enumeration of attention input layout types.

Values:

enumerator PACKED_QKV = 0#

QKV are packed into [B, S, 3, H, D] layout.

enumerator CONTIGUOUS_Q_KV#

Q has a contiguous [Compact_S, H, D] layout, while KV has a contiguous [Compact_S, 2, H, D] layout.

enumerator Q_PAGED_KV#

Q has a contiguous [B, S, H, D] layout, while the paged KV cache is addressed through a block-index table of shape [B, 2, Blocks_per_Seq], where each index gives the block's distance from the pool pointer in global memory.

enumerator SEPARATE_Q_K_V#

Q has a [B, S, H, D] layout, K has a [B, S, H_kv, D] layout, and V has a [B, S, H_kv, Dv] layout.
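
As a rough illustration of how these layouts are indexed, the sketch below computes the element offset of a given head's Q, K, or V vector under PACKED_QKV; the row-major ordering and the qkvOffset helper are assumptions for illustration, not part of fmhaParams_v2.h.

    #include <cstdint>

    // Element offset within a row-major [B, S, 3, H, D] packed QKV tensor.
    // which: 0 = Q, 1 = K, 2 = V.
    inline int64_t qkvOffset(int b, int s, int which, int h, int d,
                             int S, int H, int D)
    {
        return (((static_cast<int64_t>(b) * S + s) * 3 + which) * H + h) * D + d;
    }

Under CONTIGUOUS_Q_KV, Q_PAGED_KV, and SEPARATE_Q_K_V, Q is addressed independently of K and V, which is why FusedMultiheadAttentionParamsV2 below carries separate pointers and strides for each tensor.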

struct AlibiParams#
#include <fmhaParams_v2.h>

Parameters for ALiBi (Attention with Linear Biases) positional encoding.

Public Functions

AlibiParams() = default#
inline AlibiParams(int h, float scale_after_alibi = 1.f)#

Constructor for ALiBi parameters.

Parameters:
  • h – Number of attention heads

  • scale_after_alibi – Scaling factor to apply after ALiBi bias (default 1.0)

inline AlibiParams(
int h,
int s,
int tp_size,
int rank,
float scale_after_alibi = 1.f
)#

Constructor for ALiBi parameters with tensor parallelism support.

Parameters:
  • h – Number of attention heads per rank

  • s – Sequence length per rank

  • tp_size – Tensor parallelism size

  • rank – Current rank in tensor parallel group

  • scale_after_alibi – Scaling factor to apply after ALiBi bias (default 1.0)

Public Members

int h_pow_2 = {}#

Number of heads rounded down to nearest power of two.

float alibi_neg4_div_h = {}#

ALiBi slope computation: -4.0 / h_pow_2.

float scale_after_alibi = {}#

Scaling factor to apply after the ALiBi bias.

int head_idx_offset = 0#

Head index offset for tensor parallelism. This could be simplified to a single int rank, deriving the offsets as num_heads * rank and s * rank at runtime, but that would make assumptions about the downstream layout (e.g. downstream may split only across the head dimension, in which case s would be the full sequence length).

int sequence_pos_offset = 0#

Sequence position offset for tensor parallelism.

Public Static Functions

static inline int round_down_to_power_two(int x)#

Rounds down an integer to the nearest power of two.

Parameters:

x – The input integer value

Returns:

The largest power of two less than or equal to x
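
A minimal sketch of how these fields could be derived from the tensor-parallel constructor arguments, following the comments above; the use of the global head count (h * tp_size) for the slope denominator and the exact offset formulas are assumptions for illustration, not taken from the header.

    // Illustrative only: mirrors the documented field semantics of AlibiParams.
    struct AlibiParamsSketch
    {
        int h_pow_2 = 0;              // number of heads rounded down to a power of two
        float alibi_neg4_div_h = 0.f; // -4.0 / h_pow_2
        float scale_after_alibi = 1.f;
        int head_idx_offset = 0;      // tensor-parallel head offset
        int sequence_pos_offset = 0;  // tensor-parallel sequence offset

        static int round_down_to_power_two(int x)
        {
            int p = 1;
            while (p * 2 <= x)
            {
                p *= 2;
            }
            return p;
        }

        AlibiParamsSketch(int h, int s, int tp_size, int rank, float scale = 1.f)
        {
            // Assumption: slopes are computed over the global head count.
            h_pow_2 = round_down_to_power_two(h * tp_size);
            alibi_neg4_div_h = -4.f / static_cast<float>(h_pow_2);
            scale_after_alibi = scale;
            head_idx_offset = h * rank;     // num_heads * rank, as noted above
            sequence_pos_offset = s * rank; // s * rank, as noted above
        }
    };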

struct KvBlockArray#
#include <fmhaParams_v2.h>

Array structure for managing the blocks of a paged KV cache.

Note: the cudaTmaDesc type used by the parameter struct below is an opaque, 64-byte aligned structure holding a TMA (Tensor Memory Accelerator) descriptor.

Public Types

using PtrType = int32_t#

Public Functions

KvBlockArray() = default#
inline KvBlockArray(
int32_t batchSize,
int32_t maxBlocksPerSeq,
int32_t tokensPerBlock,
int32_t bytesPerBlock,
void *poolPtr
)#

Constructor for KV block array.

Parameters:
  • batchSize – Current number of sequences

  • maxBlocksPerSeq – Maximum number of blocks per sequence

  • tokensPerBlock – Number of tokens per block (must be a power of 2)

  • bytesPerBlock – Size of each KV cache block in bytes

  • poolPtr – Pointer to the beginning of the memory pool

Public Members

int32_t mMaxSeqs#

Current number of sequences.

int32_t mMaxBlocksPerSeq#

Max number of blocks per sequence.

int32_t mTokensPerBlock#

Number of tokens per block. It must be a power of 2.

int32_t mTokensPerBlockLog2#

Base-2 logarithm of the number of tokens per block, e.g. for mTokensPerBlock = 64, mTokensPerBlockLog2 equals 6.

int32_t mBytesPerBlock#

Size of each KV cache block in bytes (H * D * T * sizeof(DataType)).

Table mapping logical block indices to data pointers within the K/V cache block pool. Shape [B, W, 2, M], where the dimension of size 2 separates the K and V tables, B is the current number of sequences, W is the beam width, and M is the maximum number of blocks per sequence.

void *mPoolPtr#

Pointer to beginning of pool.

PtrType *mBlockOffsets#

Pointer to block offsets.
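
To make the paged addressing concrete, the sketch below resolves the base address of the block that holds a given token, following the [B, W, 2, M] offset-table comment and assuming the offsets count whole blocks from mPoolPtr; blockBaseAddress and the beamWidth parameter are hypothetical, not part of the header.

    #include <cstdint>

    // kv = 0 selects the K table, kv = 1 the V table (the dimension of size 2
    // in the [B, W, 2, M] offset table).
    inline void* blockBaseAddress(const KvBlockArray& a, int32_t batchIdx,
                                  int32_t beamWidth, int32_t beamIdx,
                                  int32_t kv, int32_t tokenIdx)
    {
        int32_t blockIdx = tokenIdx >> a.mTokensPerBlockLog2; // logical block index
        int64_t tableIdx =
            ((static_cast<int64_t>(batchIdx) * beamWidth + beamIdx) * 2 + kv)
                * a.mMaxBlocksPerSeq
            + blockIdx;
        int64_t blockOffset = a.mBlockOffsets[tableIdx]; // blocks from the pool start
        return static_cast<char*>(a.mPoolPtr) + blockOffset * a.mBytesPerBlock;
    }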

struct FusedMultiheadAttentionParamsV2#
#include <fmhaParams_v2.h>

Parameters for fused multi-head attention version 2.

Public Members

void *qkv_ptr#

The packed QKV matrices.

void *q_ptr#

The separate Q matrix.

void *k_ptr#

The separate K matrix.

void *v_ptr#

The separate V matrix.

void *kv_ptr#

The separate KV matrix (contiguous KV layout).

KvBlockArray paged_kv_cache#

The separate paged kv cache.

void *packed_mask_ptr#

The mask to implement drop-out.

float *attention_sinks#

The attention sinks (one value per head).

void *o_ptr#

The O (output) matrix.

void *softmax_stats_ptr#

The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max.
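
As a small illustration of the [2, B, S, H] layout, a hypothetical row-major index computation; which of the two slots holds softmax_sum versus softmax_max is an assumption.

    #include <cstdint>

    // stat = 0 or 1 selects softmax_sum or softmax_max (ordering assumed).
    inline int64_t softmaxStatIndex(int stat, int b, int s, int h,
                                    int B, int S, int H)
    {
        return ((static_cast<int64_t>(stat) * B + b) * S + s) * H + h;
    }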

int64_t q_stride_in_bytes#

The stride between rows of Q.

int64_t k_stride_in_bytes#

The stride between rows of K.

int64_t v_stride_in_bytes#

The stride between rows of V.

int64_t packed_mask_stride_in_bytes#

The stride between matrices of packed mask.

int64_t o_stride_in_bytes#

The stride between rows of O.

int64_t softmax_stats_stride_in_bytes#

The stride between rows of softmax_stats_ptr.

cudaTmaDesc tma_desc_q#

TMA descriptor for Q on the device: either Q within the packed QKV layout [B, S, 3, H, D] or the separate Q layout [B, S, H, D].

cudaTmaDesc tma_desc_k#

TMA descriptor for the packed/contiguous/paged KV cache. KV in the packed QKV layout: [B, S, 3, H, D]. Contiguous KV layout: [B, 2, H, S, D]. Paged KV layout: [UINT32_MAX, H, Tokens_per_block, D].

cudaTmaDesc tma_desc_v#

TMA descriptor for V.

cudaTmaDesc tma_desc_o#

TMA descriptor for O.

int blocks_per_tma_load#

Number of KV cache blocks per TMA load of the paged KV cache.

int blocks_per_tma_load_log2#

Log2 of blocks per TMA load.

int b#

The problem dimensions. In ordinary multi-head attention (MHA), the numbers of Q, K, and V heads are equal.

  • b – batch size

  • h – number of query heads

  • h_kv – number of key/value heads

  • h_q_per_kv – number of query heads per key/value head

  • s – query sequence length

  • s_kv – key/value sequence length

  • d – head dimension

int h#
int h_kv#
int h_q_per_kv#
int s#
int s_kv#
int d#
int dv = 0#

The head dimension of V. If unset (0), dv = d.
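
For example, under grouped-query attention the head counts relate as sketched below; the helper name and the concrete values are arbitrary, only the field names come from this struct.

    // Hypothetical setup helper; the numbers are arbitrary example values.
    inline void setGqaDims(FusedMultiheadAttentionParamsV2& params)
    {
        params.b = 8;                               // batch size
        params.s = 4096;                            // query sequence length
        params.s_kv = 4096;                         // key/value sequence length
        params.h = 32;                              // query heads
        params.h_kv = 8;                            // key/value heads (GQA)
        params.h_q_per_kv = params.h / params.h_kv; // 4 query heads per KV head
        params.d = 128;                             // head dimension
        params.dv = 0;                              // 0 means dv defaults to d
    }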

int num_grouped_heads = 1#

The number of grouped heads in the sequence-length dimension.

int sliding_window_size = INT_MAX#

Sliding Window Attention. Only pay attention to [max(0, query_idx - sliding_window_size), query_idx].

int log2_chunked_attention_size = 0#

The chunked attention size in log2 (a value > 0 means that chunked attention is enabled).
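
A hedged sketch of the key range these two fields imply for a given query position; the attends predicate is hypothetical, and the chunk test assumes chunk boundaries at multiples of 2^log2_chunked_attention_size.

    #include <algorithm>

    // Hypothetical helper: may query position q attend to key position k?
    inline bool attends(int q, int k, int sliding_window_size,
                        int log2_chunked_attention_size)
    {
        if (k > q)
        {
            return false; // causal
        }
        if (k < std::max(0, q - sliding_window_size))
        {
            return false; // outside the sliding window
        }
        if (log2_chunked_attention_size > 0
            && (q >> log2_chunked_attention_size) != (k >> log2_chunked_attention_size))
        {
            return false; // outside the query's chunk (assumed boundary rule)
        }
        return true;
    }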

uint32_t scale_bmm1#

The scaling factors for the kernel.

uint32_t softcapping_scale_bmm1#
uint32_t scale_softmax#
uint32_t scale_bmm2#
uint32_t *scale_bmm1_d#

The scaling factors in device memory (required by TRT-LLM + FP8 FMHA).

uint32_t *scale_bmm2_d#

The scaling factors in device memory (required by TRT-LLM + FP8 FMHA).

int *cu_q_seqlens#

Array of length b+1 holding prefix sum of actual q sequence lengths.

int *cu_kv_seqlens#

Array of length b+1 holding prefix sum of actual kv sequence lengths.

int *cu_mask_rows#

Array of length b+1 holding prefix sum of actual mask sequence lengths. It may differ from cu_q_seqlens because the mask sequence lengths are padded.
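
For illustration, such prefix-sum arrays can be built on the host from per-sequence lengths as sketched below (the helper is hypothetical). The same construction applies to cu_kv_seqlens; cu_mask_rows differs only in that the per-sequence lengths are the padded mask row counts.

    #include <vector>

    // cu[i] holds the sum of the first i sequence lengths, so the result has
    // length b + 1 and cu[b] is the total number of tokens.
    inline std::vector<int> buildCuSeqlens(const std::vector<int>& seqlens)
    {
        std::vector<int> cu(seqlens.size() + 1, 0);
        for (size_t i = 0; i < seqlens.size(); ++i)
        {
            cu[i + 1] = cu[i] + seqlens[i];
        }
        return cu;
    }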

bool has_alibi = false#

Whether the kernel uses ALiBi.

AlibiParams alibi_params = {}#

ALiBi parameters.

uint32_t *tile_id_counter_ptr#

M tile id counter for dynamic scheduling.

uint32_t num_tiles#

Total number of tiles.

uint32_t num_tiles_per_head#

Number of tiles per head.

bool use_balanced_scheduling#

Whether to use balanced scheduling.

bool is_s_padded = false#

Whether the input/output is padded.

struct FusedMultiheadAttentionParamsV2::SageAttention sage#

SAGE attention configuration.

struct SageAttention#
#include <fmhaParams_v2.h>

SAGE attention parameters.

struct Scales#
#include <fmhaParams_v2.h>

Per-block quantization scales.

Public Members

int max_nblock#

ceil(max_seqlen / block_size)

float *scales#

The scale of each block, layout: (B, H, max_nblock).
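
A small sketch of how max_nblock and the per-block scale lookup relate, assuming a row-major (B, H, max_nblock) layout; block_size and both helper names are hypothetical.

    // max_nblock = ceil(max_seqlen / block_size).
    inline int computeMaxNblock(int max_seqlen, int block_size)
    {
        return (max_seqlen + block_size - 1) / block_size;
    }

    // Index of the scale for block n of head h in sequence b.
    inline int scaleIndex(int b, int h, int n, int H, int max_nblock)
    {
        return (b * H + h) * max_nblock + n;
    }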

struct LaunchParams#
#include <fmhaParams_v2.h>

Flags to control kernel choice and launch parameters.

Public Members

bool ignore_b1opt = false#

Flags to control the small-batch kernel choice. true: never unroll.

bool force_unroll = false#

true: always unroll.

bool force_fp32_acc = false#

Use FP32 accumulation.

bool interleaved = false#

The C/32 format.

bool use_tma = false#

By default TMA is not used.

int total_q_seqlen = 0#

Total number of q tokens to set TMA descriptors.

int total_kv_seqlen = 0#

Total number of kv tokens to set TMA descriptors.

bool flash_attention = false#

Whether flash attention is used (FP16 only).

bool warp_specialization = false#

Whether warp-specialized kernels are used (SM90 HGMMA + TMA only).

bool use_granular_tiling = false#

Whether to use granular-tiling flash attention kernels.

ContextAttentionMaskType attention_mask_type = ContextAttentionMaskType::PADDING#

Causal masking, sliding_or_chunked_causal masking, or dense (padding) masking.

AttentionInputLayout attention_input_layout = AttentionInputLayout::PACKED_QKV#

The attention input layout.

bool enable_attn_logit_softcapping = false#

Enable attention logit softcapping (chooses kernels with softcapping_scale_bmm1).

int multi_processor_count = 0#

Number of multiprocessors (SMs); a hardware property used to determine how to launch blocks.

int device_l2_cache_size = 0#

Device L2 cache size in bytes.
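
Tying the launch flags together, a hedged host-side sketch of populating LaunchParams for a causal, packed-QKV configuration; only the cudaGetDevice/cudaDeviceGetAttribute calls are real CUDA API calls, the rest simply assigns the documented fields, and the helper name is hypothetical.

    #include <cuda_runtime.h>

    inline LaunchParams makeCausalPackedQkvLaunchParams()
    {
        LaunchParams launch{};
        launch.attention_mask_type = ContextAttentionMaskType::CAUSAL;
        launch.attention_input_layout = AttentionInputLayout::PACKED_QKV;
        launch.flash_attention = true;   // FP16 flash attention path
        launch.use_tma = true;           // e.g. for warp-specialized SM90 kernels
        launch.warp_specialization = true;

        // Query the hardware properties used to decide how to launch blocks.
        int device = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&launch.multi_processor_count,
                               cudaDevAttrMultiProcessorCount, device);
        cudaDeviceGetAttribute(&launch.device_l2_cache_size,
                               cudaDevAttrL2CacheSize, device);
        return launch;
    }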