Context FMHA Runner

class ContextFMHARunner

Runner for context-phase fused multi-head attention (FMHA).

Manages and dispatches optimized FMHA kernels for the prefill/context phase of LLM inference. Supports various attention patterns (causal, sliding window) and memory layouts (packed QKV, separate Q/KV, paged KV cache).

Public Functions

ContextFMHARunner(
nvinfer1::DataType const dataType,
int32_t batchSize,
int32_t paddedSeqLen,
int32_t numQHeads,
int32_t numKvHeads,
int32_t headSize,
int32_t smVersion,
AttentionInputLayout inputLayout
)

Construct context FMHA runner.

Parameters:
  • dataType – Data type (e.g., FP16, BF16)

  • batchSize – Batch size

  • paddedSeqLen – Padded sequence length

  • numQHeads – Number of query heads

  • numKvHeads – Number of key-value heads

  • headSize – Attention head dimension

  • smVersion – CUDA compute capability (e.g., 89 for SM 8.9)

  • inputLayout – Input tensor layout
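The smVersion example above (89 for SM 8.9) suggests the value is encoded as major * 10 + minor. A minimal sketch of that conversion follows, assuming this encoding holds for other architectures as well. The helper name toSmVersion is hypothetical and not part of the documented API; in a CUDA program the major and minor numbers would typically come from the major and minor fields of cudaDeviceProp.

```cpp
#include <cstdint>

// Hypothetical helper, not part of ContextFMHARunner.
// Assumption: smVersion is encoded as major * 10 + minor, which matches
// the documented example "89 for SM 8.9".
constexpr std::int32_t toSmVersion(std::int32_t major, std::int32_t minor)
{
    return major * 10 + minor;
}
```

The result of such a helper could then be passed as the smVersion constructor argument, and as the sm argument of the static functions.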

ContextFMHARunner() = delete

Deleted default constructor.

~ContextFMHARunner() = default

Destructor.

size_t getWorkspaceSize()

Get required workspace size in bytes.

Returns:

Workspace size in bytes

void setupParams(FusedMultiheadAttentionParamsV2 &params)

Set up kernel parameters (excluding device pointers).

Configures the FMHA parameters. Device pointers are not filled in here; the caller must set them before dispatching the kernel.

Parameters:

params – FMHA parameter structure

void dispatchFMHAKernel(
FusedMultiheadAttentionParamsV2 &params,
cudaStream_t const &stream
)

Dispatch FMHA kernel execution.

Parameters:
  • params – FMHA parameters with device pointers set

  • stream – CUDA stream for kernel launch
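Taken together, setupParams and dispatchFMHAKernel imply a three-step call order: configure the parameters, let the caller fill in the device pointers, then dispatch. The sketch below captures that order as a generic function template. The names runContextAttention, MockParams, MockStream, and MockRunner are hypothetical stand-ins so the example compiles without CUDA or TensorRT; the field names of the real FusedMultiheadAttentionParamsV2 are not given here and are not assumed.

```cpp
#include <string>

// Hypothetical helper: runs the three steps in the order the reference implies.
// Works with any Runner exposing setupParams(Params&) and
// dispatchFMHAKernel(Params&, Stream const&).
template <typename Runner, typename Params, typename Stream, typename SetPointers>
void runContextAttention(Runner& runner, Params& params, Stream const& stream,
                         SetPointers&& setDevicePointers)
{
    runner.setupParams(params);                 // 1. fill the non-pointer fields
    setDevicePointers(params);                  // 2. caller supplies device pointers
    runner.dispatchFMHAKernel(params, stream);  // 3. launch on the given stream
}

// Stand-in types for illustration only; they do not mirror the real structures.
struct MockParams
{
    bool configured = false;
    void* output = nullptr;
};

struct MockStream
{
};

struct MockRunner
{
    std::string callLog;

    void setupParams(MockParams& p)
    {
        p.configured = true;
        callLog += "setup;";
    }

    void dispatchFMHAKernel(MockParams& p, MockStream const&)
    {
        // Record whether both earlier steps had happened before dispatch.
        callLog += (p.configured && p.output != nullptr) ? "dispatch-ok;" : "dispatch-bad;";
    }
};
```

Passing the pointer assignment as a callable keeps the helper independent of the parameter structure's layout, which only the caller needs to know.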

Public Static Functions

static bool canImplement(
int32_t headSize,
int32_t sm,
nvinfer1::DataType dataType
)

Check whether FMHA can be implemented for the given configuration.

Parameters:
  • headSize – Attention head dimension

  • sm – CUDA compute capability (e.g., 89 for SM 8.9)

  • dataType – Data type

Returns:

True if an implementation is available

static bool loadContextFMHAKernels(
int32_t sm,
nvinfer1::DataType dataType
)

Load FMHA kernel cubins onto the device.

Parameters:
  • sm – CUDA compute capability (e.g., 89 for SM 8.9)

  • dataType – Data type

Returns:

True if the kernels were loaded successfully
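One plausible way to combine the two static functions is to check support first and load kernels only when the check passes. The sketch below assumes that order; the reference does not state whether canImplement must be called before loadContextFMHAKernels, or whether loading must precede construction. The names prepareContextFMHA, MockDataType, and MockKernelRunner are hypothetical, and the support rule inside the mock is invented purely so the example can run without CUDA.

```cpp
#include <cstdint>

// Hypothetical helper: returns true only if the configuration is supported
// and the kernels were loaded. Works with any Runner exposing the two
// static functions documented above.
template <typename Runner, typename DataType>
bool prepareContextFMHA(std::int32_t headSize, std::int32_t sm, DataType dataType)
{
    if (!Runner::canImplement(headSize, sm, dataType))
    {
        return false;  // unsupported: skip kernel loading entirely
    }
    return Runner::loadContextFMHAKernels(sm, dataType);
}

// Stand-in for nvinfer1::DataType, for illustration only.
enum class MockDataType
{
    kHALF,
    kBF16
};

// Stand-in runner. Its support rule is made up and says nothing about
// which configurations the real runner accepts.
struct MockKernelRunner
{
    static int& loadCount()
    {
        static int count = 0;
        return count;
    }

    static bool canImplement(std::int32_t headSize, std::int32_t sm, MockDataType)
    {
        return headSize == 128 && sm >= 80;
    }

    static bool loadContextFMHAKernels(std::int32_t, MockDataType)
    {
        ++loadCount();
        return true;
    }
};
```

With the real class, the call would look like prepareContextFMHA<ContextFMHARunner>(headSize, sm, dataType), with a fallback attention path chosen when it returns false.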