FMHA Params V2#
Enums
-
enum class ContextAttentionMaskType#
Enumeration of context attention mask types.
Values:
-
enumerator PADDING = 0#
Mask the padded tokens.
-
enumerator CAUSAL#
Mask the padded tokens and all tokens that come after the current token in a sequence.
-
enumerator SLIDING_OR_CHUNKED_CAUSAL#
Causal mask, additionally restricted to a specific sliding window or attention chunk.
-
enumerator CUSTOM_MASK#
Use a custom mask supplied as input.
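As a reading aid for the mask types above, the following hedged sketch evaluates whether a key position may be attended to by a query position under each mask type; the enum is re-declared locally for self-containment, and the parameters kv_seqlen, sliding_window_size and chunk_size are illustrative stand-ins for values the kernels derive from the attention parameters documented further below.

    // Illustrative only: per-position mask predicate for each mask type.
    // The actual kernels apply these rules internally; this is not the library API.
    #include <algorithm>

    enum class ContextAttentionMaskType { PADDING = 0, CAUSAL, SLIDING_OR_CHUNKED_CAUSAL, CUSTOM_MASK };

    bool is_attended(ContextAttentionMaskType type, int query_idx, int key_idx,
                     int kv_seqlen, int sliding_window_size, int chunk_size)
    {
        // PADDING: only mask out keys that fall into the padded tail of the sequence.
        bool in_seq = key_idx < kv_seqlen;
        switch (type)
        {
        case ContextAttentionMaskType::PADDING:
            return in_seq;
        case ContextAttentionMaskType::CAUSAL:
            // Additionally mask keys that come after the query position.
            return in_seq && key_idx <= query_idx;
        case ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL:
            // Causal, further restricted to a sliding window or to the query's chunk.
            if (chunk_size > 0)
            {
                return in_seq && key_idx <= query_idx && (key_idx / chunk_size) == (query_idx / chunk_size);
            }
            return in_seq && key_idx <= query_idx && key_idx >= std::max(0, query_idx - sliding_window_size);
        case ContextAttentionMaskType::CUSTOM_MASK:
        default:
            // The mask is supplied by the caller instead of being derived from positions.
            return in_seq;
        }
    }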
-
enum class AttentionInputLayout#
Enumeration of attention input layout types.
Values:
-
enumerator PACKED_QKV = 0#
QKV are packed into [B, S, 3, H, D] layout.
-
enumerator CONTIGUOUS_Q_KV#
Q has contiguous [Compact_S, H, D] layout, while KV has contiguous [Compact_S, 2, H, D] layout.
-
enumerator Q_PAGED_KV#
Q has contiguous [B, S, H, D] layout, while the paged KV layout is a table of block indices with shape [B, 2, Blocks_per_Seq]; each index gives the block's distance from the pool pointer in global memory.
-
enumerator SEPARATE_Q_K_V#
Q has [B, S, H, D] layout, K has [B, S, H_kv, D] layout, V has [B, S, H_kv, Dv] layout.
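To make the layout strings above concrete, here is a hedged sketch of the element-offset arithmetic for the PACKED_QKV layout, assuming a dense row-major [B, S, 3, H, D] tensor; the function name and the assumption of no padding between dimensions are illustrative, and the kernels themselves work from the byte strides documented further below.

    // Illustrative only: linear element offset of (b, s, qkv, h, d) in a dense
    // row-major [B, S, 3, H, D] packed QKV tensor. qkv selects 0 = Q, 1 = K, 2 = V.
    #include <cstddef>

    size_t packed_qkv_offset(int b, int s, int qkv, int h, int d, int S, int H, int D)
    {
        return (((static_cast<size_t>(b) * S + s) * 3 + qkv) * H + h) * D + d;
    }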
-
struct AlibiParams#
- #include <fmhaParams_v2.h>
Parameters for ALiBi (Attention with Linear Biases) positional encoding.
Public Functions
-
AlibiParams() = default#
-
inline AlibiParams(int h, float scale_after_alibi = 1.f)#
Constructor for ALiBi parameters.
- Parameters:
h – Number of attention heads
scale_after_alibi – Scaling factor to apply after ALiBi bias (default 1.0)
- inline AlibiParams(
- int h,
- int s,
- int tp_size,
- int rank,
- float scale_after_alibi = 1.f
- )#
Constructor for ALiBi parameters with tensor parallelism support.
- Parameters:
h – Number of attention heads per rank
s – Sequence length per rank
tp_size – Tensor parallelism size
rank – Current rank in tensor parallel group
scale_after_alibi – Scaling factor to apply after ALiBi bias (default 1.0)
Public Members
-
int h_pow_2 = {}#
Number of heads rounded down to nearest power of two.
-
float alibi_neg4_div_h = {}#
ALiBi slope computation: -4.0 / h_pow_2.
-
float scale_after_alibi = {}#
Scaling factor to apply after ALiBi bias
-
int head_idx_offset = 0#
Head index offset for tensor parallelism. Could be simplified to a single int rank and derive the others as num_heads * rank and s * rank at runtime, but this makes assumptions about the layout downstream (e.g. downstream may only split across the head dimension, so s would be the full sequence).
-
int sequence_pos_offset = 0#
Sequence position offset for tensor parallelism.
Public Static Functions
-
static inline int round_down_to_power_two(int x)#
Rounds down an integer to the nearest power of two.
- Parameters:
x – The input integer value
- Returns:
The largest power of two less than or equal to x
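Assuming the standard ALiBi definition, where head i (1-indexed, out of n = h_pow_2 heads) uses slope 2^(-8i/n), the members above are enough to sketch the bias computation. The sketch below is a reading aid only: namespace qualifiers are omitted, and the real kernels may apply the offsets and scale_after_alibi differently.

    // Illustrative only: canonical ALiBi bias added to the attention logit of
    // (query position q_pos, key position k_pos) for a given local head index.
    #include <cmath>
    #include <fmhaParams_v2.h>

    float alibi_bias(const AlibiParams& p, int local_head_idx, int q_pos, int k_pos)
    {
        // Fold in the tensor-parallel offsets (how the kernels apply them may differ).
        int head = local_head_idx + p.head_idx_offset; // 0-based global head index
        int key = k_pos + p.sequence_pos_offset;

        // Standard ALiBi slope for 1-based head i out of n = h_pow_2 heads:
        // m_i = 2^(-8 * i / n) = exp2(2 * i * alibi_neg4_div_h).
        float slope = std::exp2(2.0f * static_cast<float>(head + 1) * p.alibi_neg4_div_h);

        // The bias is the slope times the key-query distance (non-positive under a causal mask).
        return p.scale_after_alibi * slope * static_cast<float>(key - q_pos);
    }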
-
struct KvBlockArray#
- #include <fmhaParams_v2.h>
Array structure for managing paged KV cache blocks.
Public Types
-
using PtrType = int32_t#
Public Functions
-
KvBlockArray() = default#
- inline KvBlockArray(
- int32_t batchSize,
- int32_t maxBlocksPerSeq,
- int32_t tokensPerBlock,
- int32_t bytesPerBlock,
- void *poolPtr
- )#
Constructor for KV block array.
- Parameters:
batchSize – Current number of sequences
maxBlocksPerSeq – Maximum number of blocks per sequence
tokensPerBlock – Number of tokens per block (must be a power of 2)
bytesPerBlock – Size of each KV cache block in bytes
poolPtr – Pointer to the beginning of the memory pool
Public Members
-
int32_t mMaxSeqs#
Current number of sequences.
-
int32_t mMaxBlocksPerSeq#
Max number of blocks per sequence.
-
int32_t mTokensPerBlock#
Number of tokens per block. It must be a power of 2.
-
int32_t mTokensPerBlockLog2#
Base-2 logarithm of mTokensPerBlock. E.g. for mTokensPerBlock = 64, mTokensPerBlockLog2 equals 6.
-
int32_t mBytesPerBlock#
Size of KV cache blocks in bytes (H*D*T*sizeof(DataType)).
Table that maps a logical block index to the data pointer of the k/v cache block pool. Shape [B, W, 2, M], where 2 is the K/V dimension, B is the current number of sequences, W is the beam width, and M is the max number of blocks per sequence.
-
void *mPoolPtr#
Pointer to beginning of pool.
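Because mTokensPerBlock is a power of two, mTokensPerBlockLog2 lets a token's block and within-block slot be found with a shift and a mask; a hedged sketch of that arithmetic follows (the lookup of the actual block pointer goes through the block table and mPoolPtr and is not shown).

    // Illustrative only: split a sequence-local token index into the block it
    // lives in and its slot within that block, using the log2 block size.
    #include <cstdint>

    struct TokenLocation
    {
        int32_t blockIdx;     // index of the block within the sequence
        int32_t tokenInBlock; // slot of the token inside that block
    };

    TokenLocation locateToken(int32_t tokenIdx, int32_t tokensPerBlockLog2)
    {
        TokenLocation loc;
        loc.blockIdx = tokenIdx >> tokensPerBlockLog2;                 // tokenIdx / tokensPerBlock
        loc.tokenInBlock = tokenIdx & ((1 << tokensPerBlockLog2) - 1); // tokenIdx % tokensPerBlock
        return loc;
    }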
-
struct FusedMultiheadAttentionParamsV2#
- #include <fmhaParams_v2.h>
Parameters for fused multi-head attention version 2.
Public Members
-
void *qkv_ptr#
The packed QKV matrices.
-
void *q_ptr#
The separate Q matrix.
-
void *k_ptr#
The separate K matrix.
-
void *v_ptr#
The separate V matrix.
-
void *kv_ptr#
The separate KV matrix (contiguous KV)
-
KvBlockArray paged_kv_cache#
The separate paged kv cache.
-
void *packed_mask_ptr#
The mask to implement drop-out.
-
float *attention_sinks#
The attention sinks (per head)
-
void *o_ptr#
The O matrix (output)
-
void *softmax_stats_ptr#
The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max.
-
int64_t q_stride_in_bytes#
The stride between rows of Q.
-
int64_t k_stride_in_bytes#
The stride between rows of K.
-
int64_t v_stride_in_bytes#
The stride between rows of V.
-
int64_t packed_mask_stride_in_bytes#
The stride between matrices of packed mask.
-
int64_t o_stride_in_bytes#
The stride between rows of O.
-
int64_t softmax_stats_stride_in_bytes#
The stride between rows of softmax_stats_ptr.
-
cudaTmaDesc tma_desc_q#
TMA descriptor for Q on device: either Q in the packed QKV layout [B, S, 3, H, D] or the separate Q layout [B, S, H, D].
-
cudaTmaDesc tma_desc_k#
TMA descriptor for the packed/contiguous/paged KV cache. KV in packed QKV layout: [B, S, 3, H, D]. Contiguous KV layout: [B, 2, H, S, D]. Paged KV layout: [UINT32_MAX, H, Tokens_per_block, D].
-
cudaTmaDesc tma_desc_v#
TMA descriptor for V.
-
cudaTmaDesc tma_desc_o#
TMA descriptor for O.
-
int blocks_per_tma_load#
Number of paged KV cache blocks loaded per TMA load.
-
int blocks_per_tma_load_log2#
Log2 of blocks per TMA load.
-
int b#
The dimensions. In ordinary multi-head attention (MHA), the numbers of Q, K and V heads are equal. b: batch size, h: number of query heads, h_kv: number of key/value heads, h_q_per_kv: number of query heads per key/value head, s: sequence length, s_kv: key/value sequence length, d: head dimension.
-
int h#
-
int h_kv#
-
int h_q_per_kv#
-
int s#
-
int s_kv#
-
int d#
-
int dv = 0#
The dimension of V. If unset, dv = d.
-
int num_grouped_heads = 1#
The number of grouped heads in the seqlen dimension
-
int sliding_window_size = INT_MAX#
Sliding Window Attention. Only pay attention to [max(0, query_idx - sliding_window_size), query_idx].
-
int log2_chunked_attention_size = 0#
The chunked attention size in log2 (> 0 means that chunked attention is enabled)
-
uint32_t scale_bmm1#
The scaling factors for the kernel.
-
uint32_t softcapping_scale_bmm1#
-
uint32_t scale_softmax#
-
uint32_t scale_bmm2#
-
uint32_t *scale_bmm1_d#
The scaling factors in the device memory (required by TRT-LLM + FP8 FMHA)
-
uint32_t *scale_bmm2_d#
The scaling factors in the device memory (required by TRT-LLM + FP8 FMHA)
-
int *cu_q_seqlens#
Array of length b+1 holding prefix sum of actual q sequence lengths.
-
int *cu_kv_seqlens#
Array of length b+1 holding prefix sum of actual kv sequence lengths
-
int *cu_mask_rows#
Array of length b+1 holding prefix sum of actual mask sequence lengths. It might not be the same as cu_q_seqlens, since the mask sequence lengths may be padded.
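The cu_* arrays are ordinary length-(b+1) prefix sums over per-sequence lengths, so that sequence i spans [cu[i], cu[i+1]) in the packed tensors. A hedged host-side sketch of how such an array could be built follows; the helper name is illustrative, and the result would still need to be copied to device memory.

    // Illustrative only: build a length-(b+1) prefix-sum array from per-sequence
    // token counts, e.g. for cu_q_seqlens or cu_kv_seqlens.
    #include <vector>

    std::vector<int> buildCuSeqlens(const std::vector<int>& seqlens)
    {
        std::vector<int> cu(seqlens.size() + 1, 0);
        for (size_t i = 0; i < seqlens.size(); ++i)
        {
            cu[i + 1] = cu[i] + seqlens[i];
        }
        return cu;
    }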
-
bool has_alibi = false#
Whether the kernel uses ALiBi.
-
AlibiParams alibi_params = {}#
ALiBi parameters.
-
uint32_t *tile_id_counter_ptr#
M tile id counter for dynamic scheduling.
-
uint32_t num_tiles#
Total number of tiles.
-
uint32_t num_tiles_per_head#
Number of tiles per head.
-
bool use_balanced_scheduling#
Whether to use balanced scheduling.
-
bool is_s_padded = false#
Is input/output padded.
-
struct FusedMultiheadAttentionParamsV2::SageAttention sage#
SAGE attention configuration.
-
struct SageAttention#
- #include <fmhaParams_v2.h>
SAGE attention parameters.
Public Members
-
struct FusedMultiheadAttentionParamsV2::SageAttention::Scales q#
-
struct FusedMultiheadAttentionParamsV2::SageAttention::Scales k#
-
struct FusedMultiheadAttentionParamsV2::SageAttention::Scales v#
Scales for Q, K, V.
-
struct Scales#
- #include <fmhaParams_v2.h>
Per-block quantization scales.
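As a minimal, hedged sketch of how the members above relate to each other, the following fills a subset of the parameters for an fp16 packed-QKV case with equal Q/KV head counts. Every literal, the helper name, and the stride formulas (which assume a dense [B, S, 3, H, D] layout with no padding) are illustrative assumptions rather than a verified configuration; namespace qualifiers are omitted.

    // Illustrative only: partial initialization of the attention parameters for a
    // packed QKV fp16 input. The device pointers are assumed to be allocated and
    // filled elsewhere.
    #include <cuda_fp16.h>
    #include <fmhaParams_v2.h>

    FusedMultiheadAttentionParamsV2 makePackedQkvParams(
        void* dQkv, void* dOut, int* dCuQSeqlens, int* dCuKvSeqlens)
    {
        FusedMultiheadAttentionParamsV2 params = {};

        params.qkv_ptr = dQkv; // [B, S, 3, H, D] packed QKV
        params.o_ptr = dOut;   // output O

        params.b = 8;     // batch size
        params.h = 32;    // query heads
        params.h_kv = 32; // key/value heads (plain MHA: equal to h)
        params.h_q_per_kv = params.h / params.h_kv;
        params.s = 2048;  // sequence length
        params.s_kv = 2048;
        params.d = 128;   // head dimension

        // Stride between consecutive token rows inside the dense packed tensor:
        // 3 * H * D half-precision elements.
        params.q_stride_in_bytes = 3LL * params.h * params.d * sizeof(half);
        params.k_stride_in_bytes = params.q_stride_in_bytes;
        params.v_stride_in_bytes = params.q_stride_in_bytes;
        params.o_stride_in_bytes = 1LL * params.h * params.d * sizeof(half);

        params.cu_q_seqlens = dCuQSeqlens;   // length b+1 prefix sums
        params.cu_kv_seqlens = dCuKvSeqlens;

        return params;
    }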
-
struct LaunchParams#
- #include <fmhaParams_v2.h>
Flags to control kernel choice and launch parameters.
Public Members
-
bool ignore_b1opt = false#
Flags to control small batch kernel choice. true: never unroll
-
bool force_unroll = false#
true: always unroll
-
bool force_fp32_acc = false#
Use FP32 accumulation.
-
bool interleaved = false#
The C/32 format.
-
bool use_tma = false#
By default TMA is not used.
-
int total_q_seqlen = 0#
Total number of Q tokens, used to set TMA descriptors.
-
int total_kv_seqlen = 0#
Total number of KV tokens, used to set TMA descriptors.
-
bool flash_attention = false#
If flash attention is used (only FP16)
-
bool warp_specialization = false#
If warp-specialized kernels are used (only SM90 HGMMA + TMA)
-
bool use_granular_tiling = false#
Use granular-tiling flash attention kernels.
-
ContextAttentionMaskType attention_mask_type = ContextAttentionMaskType::PADDING#
Causal masking, sliding_or_chunked_causal masking, or dense (padding) masking.
-
AttentionInputLayout attention_input_layout = AttentionInputLayout::PACKED_QKV#
The attention input layout.
-
bool enable_attn_logit_softcapping = false#
Enable attention logit softcapping (choose kernels with softcapping_scale_bmm1)
-
int multi_processor_count = 0#
Hardware properties to determine how to launch blocks.
-
int device_l2_cache_size = 0#
Device L2 cache size in bytes.
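Finally, a hedged sketch of how these launch flags might be populated for a causal, packed-QKV flash-attention run; the device-property queries are standard CUDA runtime calls, while the specific flag choices and the helper name are illustrative assumptions rather than recommended settings (namespace qualifiers omitted, error checking elided).

    // Illustrative only: fill LaunchParams for a causal flash-attention configuration
    // on the given CUDA device.
    #include <cuda_runtime.h>
    #include <fmhaParams_v2.h>

    LaunchParams makeLaunchParams(int device, int totalQTokens, int totalKvTokens)
    {
        LaunchParams launch = {};

        launch.flash_attention = true;
        launch.force_fp32_acc = true; // accumulate in FP32
        launch.attention_mask_type = ContextAttentionMaskType::CAUSAL;
        launch.attention_input_layout = AttentionInputLayout::PACKED_QKV;

        // Total token counts, used when setting up TMA descriptors.
        launch.total_q_seqlen = totalQTokens;
        launch.total_kv_seqlen = totalKvTokens;

        // Hardware properties that guide how blocks are launched.
        cudaDeviceGetAttribute(&launch.multi_processor_count, cudaDevAttrMultiProcessorCount, device);
        cudaDeviceGetAttribute(&launch.device_l2_cache_size, cudaDevAttrL2CacheSize, device);

        return launch;
    }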