CuTe DSL FMHA Runner

class CuteDslFMHARunner

Unified runner for CuTe DSL compiled FMHA kernels (Blackwell SM100+).

Supports two execution modes via separate AOT-compiled kernel variants:

  1. LLM prefill/chunked-prefill: batched Q [B,S_q,H_q,D] + combined KV cache [B,2,H_kv,Cap,D] with causal masking and optional sliding window.

  2. ViT: packed varlen separate Q/K/V [total_S,H,D] with cu_seqlens [B+1] for ragged batching, bidirectional attention (no causal mask).

Each mode has its own kernel modules and run() overload.
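To make the combined KV cache shape concrete, the sketch below computes a flat element offset into a [B,2,H_kv,Cap,D] buffer. It is illustrative only: `KVCacheShape` and `kvCacheOffset` are hypothetical names that are not part of the runner API, and the sketch assumes a dense row-major buffer in which index 0 on the second axis selects K and index 1 selects V. The source does not state the K/V order or the strides.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical shape descriptor (not part of CuteDslFMHARunner).
struct KVCacheShape
{
    int32_t batchSize;  // B
    int32_t numKVHeads; // H_kv
    int32_t capacity;   // Cap
    int32_t headDim;    // D
};

// Flat element offset into a dense row-major [B, 2, H_kv, Cap, D] buffer.
// ASSUMPTION: kv == 0 selects K and kv == 1 selects V.
inline size_t kvCacheOffset(KVCacheShape const& s, int32_t b, int32_t kv, int32_t h, int32_t pos, int32_t d)
{
    assert(b >= 0 && b < s.batchSize);
    assert(kv == 0 || kv == 1);
    assert(h >= 0 && h < s.numKVHeads);
    assert(pos >= 0 && pos < s.capacity);
    assert(d >= 0 && d < s.headDim);

    size_t off = static_cast<size_t>(b);
    off = off * 2 + static_cast<size_t>(kv);
    off = off * static_cast<size_t>(s.numKVHeads) + static_cast<size_t>(h);
    off = off * static_cast<size_t>(s.capacity) + static_cast<size_t>(pos);
    off = off * static_cast<size_t>(s.headDim) + static_cast<size_t>(d);
    return off;
}
```

Under these assumptions, the K and V planes of one batch entry are `H_kv * Cap * D` elements apart, and consecutive batch entries are twice that distance apart.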

Public Functions

CuteDslFMHARunner(
    int32_t numQHeads,
    int32_t numKVHeads,
    int32_t headDim,
    int32_t batchSize = 0,
    int32_t seqLenQ = 0,
    int32_t kvCacheCapacity = 0
)
~CuteDslFMHARunner() = default
CuteDslFMHARunner(CuteDslFMHARunner const&) = delete
CuteDslFMHARunner &operator=(CuteDslFMHARunner const&) = delete
void run(
    void const *qPtr,
    void const *kvPtr,
    void *oPtr,
    int32_t const *cuKVSeqLens,
    cudaStream_t stream,
    int32_t slidingWindowSize = INT_MAX
)

LLM FMHA: batched Q + combined KV cache with causal masking.

Parameters:
  • qPtr – Query [B, S_q, H_q, D]

  • kvPtr – Combined KV cache [B, 2, H_kv, Cap, D]

  • oPtr – Output [B, S_q, H_q, D]

  • cuKVSeqLens – Cumulative KV sequence lengths [B+1]

  • stream – CUDA stream

  • slidingWindowSize – Sliding window size (INT_MAX = disabled)
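As a rough illustration of two of these parameters, the sketch below builds a cumulative-length array of size B+1 from per-sequence KV lengths and expresses causal masking with an optional sliding window as a predicate. Both helpers are hypothetical and are not part of the runner. The window convention used here (a query sees itself plus the `slidingWindowSize - 1` preceding positions) is an assumption, as is the use of absolute token positions; the kernel's actual convention is not stated in this reference. Whether `cuKVSeqLens` must reside in device memory is also not stated.

```cpp
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side helper: per-sequence KV lengths -> cumulative [B+1].
// cu[0] is 0 and cu[i + 1] - cu[i] is the KV length of sequence i.
inline std::vector<int32_t> buildCuKVSeqLens(std::vector<int32_t> const& kvLens)
{
    std::vector<int32_t> cu(kvLens.size() + 1, 0);
    for (size_t i = 0; i < kvLens.size(); ++i)
    {
        assert(kvLens[i] >= 0);
        cu[i + 1] = cu[i] + kvLens[i];
    }
    return cu;
}

// Hypothetical reference predicate: may the query at absolute position qPos
// attend to the key at absolute position kvPos?
// ASSUMPTION: the window counts the query's own position, so a window of W
// exposes positions (qPos - W, qPos]. INT_MAX leaves only the causal mask.
inline bool isVisible(int32_t qPos, int32_t kvPos, int32_t slidingWindowSize = INT_MAX)
{
    if (kvPos > qPos)
    {
        return false; // causal mask: no attention to future positions
    }
    return (qPos - kvPos) < slidingWindowSize;
}
```

For example, KV lengths {5, 3, 7} give the cumulative array {0, 5, 8, 15}, which matches the [B+1] shape listed above for B = 3.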

void run(
    void const *qPtr,
    void const *kPtr,
    void const *vPtr,
    void *oPtr,
    int32_t const *cuSeqLens,
    int32_t totalSeqLen,
    int32_t maxSeqLen,
    int32_t batchSize,
    cudaStream_t stream
)

ViT FMHA: packed varlen separate Q/K/V, bidirectional.

Parameters:
  • qPtr – Query [total_S, H, D]

  • kPtr – Key [total_S, H, D]

  • vPtr – Value [total_S, H, D]

  • oPtr – Output [total_S, H, D]

  • cuSeqLens – Cumulative sequence lengths [B+1]

  • totalSeqLen – Sum of all sequence lengths

  • maxSeqLen – Longest individual sequence length

  • batchSize – Number of sequences

  • stream – CUDA stream
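The three scalar arguments are related to `cuSeqLens` by simple arithmetic, which the following sketch spells out on the host. `VarlenStats` and `varlenStats` are hypothetical names introduced for illustration and are not part of the runner. The sketch assumes `cuSeqLens` starts at 0 and is non-decreasing, as is conventional for packed varlen batches.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the scalar arguments of the ViT overload.
struct VarlenStats
{
    int32_t totalSeqLen; // sum of all sequence lengths
    int32_t maxSeqLen;   // longest individual sequence
    int32_t batchSize;   // number of sequences (B)
};

// Derive the scalars from a host copy of cuSeqLens, shape [B+1].
inline VarlenStats varlenStats(std::vector<int32_t> const& cuSeqLens)
{
    VarlenStats s{0, 0, 0};
    if (cuSeqLens.size() < 2)
    {
        return s; // no sequences
    }
    s.batchSize = static_cast<int32_t>(cuSeqLens.size() - 1);
    s.totalSeqLen = cuSeqLens.back() - cuSeqLens.front();
    for (size_t i = 0; i + 1 < cuSeqLens.size(); ++i)
    {
        int32_t const len = cuSeqLens[i + 1] - cuSeqLens[i];
        assert(len >= 0); // cumulative lengths must be non-decreasing
        s.maxSeqLen = std::max(s.maxSeqLen, len);
    }
    return s;
}
```

With `cuSeqLens = {0, 196, 452, 500}`, the sequences have lengths 196, 256, and 48, so `batchSize` is 3, `totalSeqLen` is 500, and `maxSeqLen` is 256. Sequence `i` then occupies rows `cuSeqLens[i]` through `cuSeqLens[i + 1] - 1` of each packed [total_S, H, D] tensor.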

Public Static Functions

static bool canImplement(int32_t headSize, int32_t smVersion)
static bool canImplementViT(int32_t headSize, int32_t smVersion)
static bool loadLLMKernelModule()
static void unloadLLMKernelModule()
static bool loadViTKernelModule()
static void unloadViTKernelModule()