Context FMHA Runner
-
class ContextFMHARunner
Runner for context-phase fused multi-head attention (FMHA).
Manages and dispatches optimized FMHA kernels for the prefill/context phase of LLM inference. Supports various attention patterns (causal, sliding window) and memory layouts (packed QKV, separate Q/KV, paged KV cache).
Public Functions
- ContextFMHARunner(
- nvinfer1::DataType const dataType,
- int32_t batchSize,
- int32_t paddedSeqLen,
- int32_t numQHeads,
- int32_t numKvHeads,
- int32_t headSize,
- int32_t smVersion,
- AttentionInputLayout inputLayout
- )
Construct context FMHA runner.
- Parameters:
dataType – Data type (e.g., FP16, BF16)
batchSize – Batch size
paddedSeqLen – Padded sequence length
numQHeads – Number of query heads
numKvHeads – Number of key-value heads
headSize – Attention head dimension
smVersion – CUDA compute capability (e.g., 89 for SM 8.9)
inputLayout – Input tensor layout
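The `smVersion` example above ("89 for SM 8.9") suggests that the compute capability's major and minor numbers are folded into one integer. A minimal sketch of that encoding follows, assuming the common `major * 10 + minor` convention. `toSmVersion` is a hypothetical helper and not part of this API; in a CUDA program the two inputs would typically come from the `major` and `minor` fields of `cudaDeviceProp`.

```cpp
#include <cstdint>

// Hypothetical helper (not part of ContextFMHARunner): fold a compute
// capability such as 8.9 into the single-integer form used by smVersion.
// Assumes the major * 10 + minor convention implied by "89 for SM 8.9".
constexpr int32_t toSmVersion(int32_t major, int32_t minor)
{
    return major * 10 + minor;
}
```

Under this assumption, a device reporting capability 8.9 would be passed to the constructor as `toSmVersion(8, 9)`.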
-
ContextFMHARunner() = delete
Deleted default constructor.
-
~ContextFMHARunner() = default
Destructor.
-
size_t getWorkspaceSize()
Get required workspace size in bytes.
- Returns:
Workspace size
-
void setupParams(FusedMultiheadAttentionParamsV2 &params)
Set up kernel parameters (excluding device pointers).
Configures the FMHA parameters. The caller must set the device pointers separately, before dispatching the kernel.
- Parameters:
params – FMHA parameter structure
- void dispatchFMHAKernel(
- FusedMultiheadAttentionParamsV2 &params,
- cudaStream_t const &stream
- )
Dispatch FMHA kernel execution.
- Parameters:
params – FMHA parameters with device pointers set
stream – CUDA stream for kernel launch
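The split between `setupParams` and `dispatchFMHAKernel` implies a call order: configure the parameters, attach the device pointers, then dispatch. The sketch below illustrates that order with a mock so it compiles without TensorRT or CUDA. Everything except the two method names is an assumption: the struct fields `qkvPtr` and `outputPtr`, the `StreamHandle` alias standing in for `cudaStream_t`, the mock class, and the choice to set pointers after `setupParams` rather than before.

```cpp
#include <string>
#include <vector>

// Stand-ins so the sketch compiles without TensorRT or CUDA headers.
// The field names below are hypothetical; the real struct may differ.
struct FusedMultiheadAttentionParamsV2
{
    void const* qkvPtr = nullptr; // hypothetical device pointer to the input
    void* outputPtr = nullptr;    // hypothetical device pointer to the output
};
using StreamHandle = void*; // stands in for cudaStream_t

// Mock exposing the two documented method names. It records the call order
// and whether the pointers were populated when dispatch was requested.
class MockContextFMHARunner
{
public:
    void setupParams(FusedMultiheadAttentionParamsV2& params)
    {
        (void) params; // the real method configures the non-pointer fields
        calls.push_back("setupParams");
    }

    void dispatchFMHAKernel(FusedMultiheadAttentionParamsV2& params, StreamHandle const& stream)
    {
        (void) stream; // no kernel is launched in this mock
        pointersSetAtDispatch = params.qkvPtr != nullptr && params.outputPtr != nullptr;
        calls.push_back("dispatchFMHAKernel");
    }

    std::vector<std::string> calls;
    bool pointersSetAtDispatch = false;
};

// Assumed order: configure params, attach caller-owned pointers, dispatch.
inline void runContextAttention(
    MockContextFMHARunner& runner, void const* qkv, void* output, StreamHandle stream)
{
    FusedMultiheadAttentionParamsV2 params{};
    runner.setupParams(params);
    params.qkvPtr = qkv;
    params.outputPtr = output;
    runner.dispatchFMHAKernel(params, stream);
}
```

With the real runner, the pointers would refer to device buffers owned by the caller, and the stream would be the one on which the rest of the prefill work is enqueued.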
Public Static Functions
- static bool canImplement(
- int32_t headSize,
- int32_t sm,
- nvinfer1::DataType dataType
- )
Check whether FMHA can be implemented for the given configuration.
- Parameters:
headSize – Attention head dimension
sm – CUDA compute capability
dataType – Data type
- Returns:
True if an implementation is available
- static bool loadContextFMHAKernels(
- int32_t sm,
- nvinfer1::DataType dataType
- )
Load FMHA kernel cubins onto the device.
- Parameters:
sm – CUDA compute capability
dataType – Data type
- Returns:
True if successful
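Because both static functions return `bool`, one plausible way to use them is as a gate before constructing a runner: check support first, then load the kernels, and fall back to another attention path if either step fails. The sketch below shows that pattern against a stub. The `DataType` enum is a stand-in for `nvinfer1::DataType`, `StubRunner` and its support rule are arbitrary inventions for illustration, and `AttentionPath` and `chooseAttentionPath` are hypothetical names. Whether the library requires this exact ordering is an assumption.

```cpp
#include <cstdint>

enum class DataType { kHALF, kBF16 }; // stand-in for nvinfer1::DataType

// Stub with the documented static signatures. The rules inside are arbitrary
// and exist only so the sketch has something to branch on.
struct StubRunner
{
    static bool canImplement(int32_t headSize, int32_t sm, DataType)
    {
        return headSize == 128 && sm >= 80;
    }

    static bool loadContextFMHAKernels(int32_t sm, DataType)
    {
        return sm >= 80;
    }
};

enum class AttentionPath { kFusedContext, kFallback }; // hypothetical

// Use the fused context path only if the configuration is supported and the
// kernels loaded successfully; otherwise report that a fallback is needed.
template <typename Runner>
AttentionPath chooseAttentionPath(int32_t headSize, int32_t sm, DataType dataType)
{
    if (!Runner::canImplement(headSize, sm, dataType))
    {
        return AttentionPath::kFallback;
    }
    if (!Runner::loadContextFMHAKernels(sm, dataType))
    {
        return AttentionPath::kFallback;
    }
    return AttentionPath::kFusedContext;
}
```

Templating on the runner type keeps the gate testable: the stub can be swapped for the real class without changing the decision logic.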