Context FMHA Runner
class ContextFMHARunner
Runner for context-phase fused multi-head attention (FMHA).
Manages and dispatches optimized FMHA kernels for the prefill/context phase of LLM inference. Supports various attention patterns (causal, sliding window) and memory layouts (packed QKV, separate Q/KV, paged KV cache).
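The two attention patterns named above differ only in which key positions each query may attend to. The sketch below writes them as predicates over a query index `q` and a key index `k` within one sequence. The helper names, the `windowSize` input, and the exact window convention (the window ends at, and includes, the query's own position) are assumptions for illustration; they do not describe how the fused kernels implement masking.

```cpp
#include <cstdint>

// Hypothetical helper: causal masking lets query position q attend to
// every key position k at or before it.
constexpr bool causalAllows(int32_t q, int32_t k)
{
    return k <= q;
}

// Hypothetical helper: sliding-window masking is causal masking further
// restricted to the most recent windowSize keys (assumed convention:
// the window includes the query's own position).
constexpr bool slidingWindowAllows(int32_t q, int32_t k, int32_t windowSize)
{
    return k <= q && k > q - windowSize;
}

// Counts the keys visible to query q in a sequence of seqLen tokens.
// Under causal masking the count grows with q; under a sliding window
// it is capped at windowSize.
constexpr int32_t visibleKeys(int32_t q, int32_t seqLen, int32_t windowSize, bool sliding)
{
    int32_t count = 0;
    for (int32_t k = 0; k < seqLen; ++k)
    {
        if (sliding ? slidingWindowAllows(q, k, windowSize) : causalAllows(q, k))
        {
            ++count;
        }
    }
    return count;
}
```

With `seqLen = 8` and `windowSize = 3`, query 5 sees six keys under causal masking but only keys 3, 4, and 5 under the sliding window.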
Public Functions
ContextFMHARunner(
    nvinfer1::DataType const dataType,
    int32_t batchSize,
    int32_t paddedSeqLen,
    int32_t numQHeads,
    int32_t numKvHeads,
    int32_t headSize,
    int32_t smVersion,
    AttentionInputLayout inputLayout,
    ContextAttentionMaskType maskType = ContextAttentionMaskType::CAUSAL,
    bool isSPadded = true)
Construct a context FMHA runner.
- Parameters:
dataType – Data type (e.g., FP16, BF16)
batchSize – Batch size
paddedSeqLen – Padded sequence length
numQHeads – Number of query heads
numKvHeads – Number of key-value heads
headSize – Attention head dimension
smVersion – CUDA compute capability (e.g., 89 for SM 8.9)
inputLayout – Input tensor layout
maskType – Attention mask type (default: ContextAttentionMaskType::CAUSAL)
- Throws:
std::runtime_error – if a CUDA error occurs, or if the SM is not supported
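Two constructor arguments are usually derived rather than chosen. The sketch below shows how `smVersion` can be formed from a device's compute capability (8.9 becomes 89, matching the example above) and how query heads would share key-value heads when `numKvHeads` is smaller than `numQHeads`. In real code `major` and `minor` would come from the `cudaDeviceProp` fields of the same names. The helper names, and the assumption that `numQHeads` is a multiple of `numKvHeads` (the usual grouped-query attention layout), are illustrative and not part of this API.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: encode compute capability major.minor as the
// integer form used by smVersion, e.g. (8, 9) -> 89.
constexpr int32_t smVersionFromCapability(int32_t major, int32_t minor)
{
    return major * 10 + minor;
}

// Hypothetical helper (assumed grouped-query layout): consecutive groups of
// numQHeads / numKvHeads query heads share one KV head.
// numKvHeads == numQHeads is ordinary multi-head attention;
// numKvHeads == 1 is multi-query attention.
int32_t kvHeadForQueryHead(int32_t qHead, int32_t numQHeads, int32_t numKvHeads)
{
    if (numKvHeads <= 0 || numQHeads % numKvHeads != 0)
    {
        throw std::invalid_argument("numQHeads must be a multiple of a positive numKvHeads");
    }
    if (qHead < 0 || qHead >= numQHeads)
    {
        throw std::out_of_range("qHead out of range");
    }
    int32_t const groupSize = numQHeads / numKvHeads;
    return qHead / groupSize;
}
```

For example, with 32 query heads and 8 KV heads the group size is 4, so query heads 0 to 3 read KV head 0 and query head 31 reads KV head 7.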
ContextFMHARunner() = delete
Deleted default constructor.
~ContextFMHARunner() noexcept = default
Destructor.
size_t getWorkspaceSize()
Get required workspace size in bytes.
- Returns:
Workspace size
void setupParams(FusedMultiheadAttentionParamsV2 &params)
Set up kernel parameters (excluding device pointers).
Configures the FMHA parameters; device pointers must be set by the caller.
- Parameters:
params – FMHA parameter structure
- Throws:
std::runtime_error – if the input layout or alpha type is unsupported
void dispatchFMHAKernel(
    FusedMultiheadAttentionParamsV2 &params,
    cudaStream_t const &stream)
Dispatch FMHA kernel execution.
- Parameters:
params – FMHA parameters with device pointers set
stream – CUDA stream for kernel launch
- Throws:
std::runtime_error – if the device pointers are invalid, or if a CUDA error occurs
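The split between `setupParams` and `dispatchFMHAKernel` implies a call order: let the runner configure the parameters, set the device pointers yourself, then dispatch. The sketch below checks one plausible sequence (workspace size first, then setup, pointers, and launch) against hypothetical stand-ins so it compiles without TensorRT or CUDA. `MockRunner`, `MockParams`, its `qkvPtr`/`outputPtr` fields, and the use of `int` in place of `cudaStream_t` are all invented for illustration; the real `FusedMultiheadAttentionParamsV2` field names are not documented here.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for FusedMultiheadAttentionParamsV2 with only the
// fields this sketch needs.
struct MockParams
{
    bool configured = false;       // set by setupParams
    void const* qkvPtr = nullptr;  // device pointer, set by the caller
    void* outputPtr = nullptr;     // device pointer, set by the caller
};

// Hypothetical stand-in for ContextFMHARunner that records the call order.
struct MockRunner
{
    std::vector<std::string> calls;

    std::size_t getWorkspaceSize()
    {
        calls.push_back("getWorkspaceSize");
        return 1024;  // arbitrary size for the mock
    }

    void setupParams(MockParams& params)
    {
        calls.push_back("setupParams");
        params.configured = true;  // configures non-pointer fields only
    }

    // int stands in for cudaStream_t so the sketch builds without CUDA.
    void dispatchFMHAKernel(MockParams& params, int const& /*stream*/)
    {
        if (!params.configured || params.qkvPtr == nullptr || params.outputPtr == nullptr)
        {
            // Mirrors the documented Throws clause for invalid device pointers.
            throw std::runtime_error("invalid params");
        }
        calls.push_back("dispatchFMHAKernel");
    }
};

// Generic driver for any runner exposing the three documented calls.
// Returns the workspace size the caller would need to allocate.
template <typename Runner, typename Params, typename Stream>
std::size_t runContextAttention(Runner& runner, Params& params, void const* qkv, void* out, Stream const& stream)
{
    std::size_t const workspaceBytes = runner.getWorkspaceSize();
    runner.setupParams(params);                 // 1. runner configures parameters
    params.qkvPtr = qkv;                        // 2. caller supplies device pointers
    params.outputPtr = out;
    runner.dispatchFMHAKernel(params, stream);  // 3. launch on the stream
    return workspaceBytes;
}
```

Writing the driver as a template keeps the ordering logic independent of the concrete runner, so the same function could be exercised against a mock in tests and against the real class in production.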
Public Static Functions
static bool canImplement(
    int32_t headSize,
    int32_t sm,
    nvinfer1::DataType dataType,
    AttentionInputLayout inputLayout,
    ContextAttentionMaskType maskType)
Check whether an FMHA implementation is available for the given combination of head size, SM version, data type, input layout, and mask type.
- Parameters:
headSize – Attention head dimension
sm – CUDA compute capability
dataType – Data type
inputLayout – Input tensor layout
maskType – Attention mask type
- Returns:
True if an implementation is available
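Because the constructor throws when the SM is not supported, a caller can query `canImplement` first and choose a fallback path instead. The sketch below does this generically for any type with a matching static `canImplement`. The enums, `AttentionPath`, and the `MockFmhaSupport` support table (SM 80 or newer with head size 64 or 128) are invented stand-ins so the example compiles without TensorRT headers; they do not reflect which configurations the real kernels support.

```cpp
#include <cstdint>

// Hypothetical stand-ins for nvinfer1::DataType, AttentionInputLayout and
// ContextAttentionMaskType; enumerator names are illustrative.
enum class DataType { kHALF, kBF16 };
enum class InputLayout { PackedQkv, SeparateQkv, PagedKv };
enum class MaskType { Causal, SlidingWindow };

enum class AttentionPath { FusedFMHA, Unfused };

// Hypothetical support table, invented for illustration only.
struct MockFmhaSupport
{
    static bool canImplement(int32_t headSize, int32_t sm, DataType, InputLayout, MaskType)
    {
        return sm >= 80 && (headSize == 64 || headSize == 128);
    }
};

// Checks support before any runner is constructed and picks a fallback
// path rather than relying on the constructor to throw.
template <typename Runner>
AttentionPath selectAttentionPath(
    int32_t headSize, int32_t sm, DataType dataType, InputLayout layout, MaskType mask)
{
    return Runner::canImplement(headSize, sm, dataType, layout, mask) ? AttentionPath::FusedFMHA
                                                                      : AttentionPath::Unfused;
}
```

With the real class, the template argument would be `ContextFMHARunner` and the stand-in enums would be replaced by the types in the signature above.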
static bool loadContextFMHAKernels(
    int32_t sm,
    nvinfer1::DataType dataType)
Load the FMHA kernel cubins for the given SM version and data type onto the device.
- Parameters:
sm – CUDA compute capability
dataType – Data type
- Throws:
std::runtime_error – if a CUDA driver error occurs
- Returns:
True if successful