CuTe DSL FMHA Runner

class CuteDslFMHARunner
Unified runner for CuTe DSL-compiled fused multi-head attention (FMHA) kernels (Blackwell SM100+).
Supports two execution modes via separate AOT-compiled kernel variants:
LLM prefill/chunked-prefill: batched Q [B,S_q,H_q,D] + combined KV cache [B,2,H_kv,Cap,D] with causal masking and optional sliding window.
ViT: packed varlen separate Q/K/V [total_S,H,D] with cu_seqlens [B+1] for ragged batching, bidirectional attention (no causal mask).
Each mode has its own kernel modules and run() overload.
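To make the two layouts concrete, the helpers below compute element offsets into each one. This is a minimal sketch that assumes dense, contiguous, row-major tensors with exactly the documented shapes, and assumes index 0 of the size-2 dimension of the KV cache holds K and index 1 holds V. `kvCacheOffset` and `packedOffset` are hypothetical names, not part of the runner API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper. Combined KV cache [B, 2, H_kv, Cap, D], assumed dense and
// row-major; kOrV = 0 is assumed to select K and kOrV = 1 to select V.
inline int64_t kvCacheOffset(int64_t b, int64_t kOrV, int64_t h, int64_t pos, int64_t d,
                             int64_t numKVHeads, int64_t capacity, int64_t headDim)
{
    assert(kOrV == 0 || kOrV == 1);
    return (((b * 2 + kOrV) * numKVHeads + h) * capacity + pos) * headDim + d;
}

// Hypothetical helper. Packed varlen tensor [total_S, H, D]: sequences are stored
// back to back, so token t of sequence b sits at row cuSeqLens[b] + t.
inline int64_t packedOffset(int32_t const* cuSeqLens, int64_t b, int64_t t, int64_t h,
                            int64_t d, int64_t numHeads, int64_t headDim)
{
    int64_t const row = static_cast<int64_t>(cuSeqLens[b]) + t;
    return (row * numHeads + h) * headDim + d;
}
```

With this layout, the V half of batch entry 0 begins exactly `H_kv * Cap * D` elements after its K half, and the packed layout needs no padding because `cuSeqLens` supplies each sequence's starting row.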
Public Functions
CuteDslFMHARunner(int32_t numQHeads, int32_t numKVHeads, int32_t headDim, int32_t batchSize = 0, int32_t seqLenQ = 0, int32_t kvCacheCapacity = 0)
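For the LLM mode, the constructor's dimensions fix the sizes of the buffers later passed to run(). The sketch below derives element counts from the documented shapes. It assumes dense row-major buffers. It also assumes grouped-query attention requires `numQHeads` to be a multiple of `numKVHeads`, which is not stated on this page. `llmBufferElems` is a hypothetical helper, not part of the runner.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical element counts for the LLM-mode buffers (dense row-major assumed).
struct LlmBufferElems
{
    int64_t q;   // Q  [B, S_q, H_q, D]
    int64_t kv;  // combined KV cache [B, 2, H_kv, Cap, D]
    int64_t o;   // O  [B, S_q, H_q, D], same shape as Q
};

inline LlmBufferElems llmBufferElems(int32_t numQHeads, int32_t numKVHeads, int32_t headDim,
                                     int32_t batchSize, int32_t seqLenQ, int32_t kvCacheCapacity)
{
    // Assumption: GQA-style head grouping, so query heads divide evenly over KV heads.
    assert(numKVHeads > 0 && numQHeads % numKVHeads == 0);
    int64_t const q = int64_t{batchSize} * seqLenQ * numQHeads * headDim;
    int64_t const kv = int64_t{batchSize} * 2 * numKVHeads * kvCacheCapacity * headDim;
    return {q, kv, q};
}
```

Multiply each count by the element size of the data type actually used (not specified here) to get byte sizes.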
~CuteDslFMHARunner() = default

CuteDslFMHARunner(CuteDslFMHARunner const&) = delete

CuteDslFMHARunner &operator=(CuteDslFMHARunner const&) = delete
void run(void const *qPtr, void const *kvPtr, void *oPtr, int32_t const *cuKVSeqLens, cudaStream_t stream, int32_t slidingWindowSize = INT_MAX)
LLM FMHA: batched Q + combined KV cache with causal masking.
Parameters:
qPtr – Query [B, S_q, H_q, D]
kvPtr – Combined KV cache [B, 2, H_kv, Cap, D]
oPtr – Output [B, S_q, H_q, D]
cuKVSeqLens – Cumulative KV sequence lengths [B+1]
stream – CUDA stream
slidingWindowSize – Sliding window size (INT_MAX = disabled)
void run(void const *qPtr, void const *kPtr, void const *vPtr, void *oPtr, int32_t const *cuSeqLens, int32_t totalSeqLen, int32_t maxSeqLen, int32_t batchSize, cudaStream_t stream)
ViT FMHA: packed varlen separate Q/K/V, bidirectional.
Parameters:
qPtr – Query [total_S, H, D]
kPtr – Key [total_S, H, D]
vPtr – Value [total_S, H, D]
oPtr – Output [total_S, H, D]
cuSeqLens – Cumulative sequence lengths [B+1]
totalSeqLen – Sum of all sequence lengths
maxSeqLen – Longest individual sequence length
batchSize – Number of sequences
stream – CUDA stream
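All of this overload's scalar metadata can be derived from the per-sequence token counts, such as patches per image. The sketch below computes `cuSeqLens`, `totalSeqLen`, `maxSeqLen`, and `batchSize` together so they stay consistent. `VitVarlenMeta` and `buildVitMeta` are hypothetical names. Placing `cuSeqLens` in device memory before the call is an assumption, not something this page states.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical bundle of the ViT run() metadata, built on the host.
struct VitVarlenMeta
{
    std::vector<int32_t> cuSeqLens;  // [B + 1], cuSeqLens[0] = 0
    int32_t totalSeqLen;             // sum of all lengths == cuSeqLens[B]
    int32_t maxSeqLen;               // longest individual sequence
    int32_t batchSize;               // B
};

inline VitVarlenMeta buildVitMeta(std::vector<int32_t> const& seqLens)
{
    assert(!seqLens.empty());
    VitVarlenMeta m{std::vector<int32_t>(seqLens.size() + 1, 0), 0, 0,
                    static_cast<int32_t>(seqLens.size())};
    for (size_t b = 0; b < seqLens.size(); ++b)
    {
        m.cuSeqLens[b + 1] = m.cuSeqLens[b] + seqLens[b];
        m.maxSeqLen = std::max(m.maxSeqLen, seqLens[b]);
    }
    m.totalSeqLen = m.cuSeqLens.back();
    return m;
}
```

For three images of 196, 64, and 256 tokens, this gives `cuSeqLens = {0, 196, 260, 516}`, `totalSeqLen = 516`, `maxSeqLen = 256`, and `batchSize = 3`. The Q, K, V, and O buffers would each then hold 516 rows of `[H, D]`.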