Decoder XQA Runner#

class DecoderXQARunner#

Decoder XQA (eXtended Query Attention) kernel runner.

Public Functions

DecoderXQARunner(
nvinfer1::DataType const dataType,
int32_t batchSize,
int32_t numQHeads,
int32_t numKvHeads,
int32_t headSize,
int32_t smVersion
)#

Constructor for DecoderXQARunner.

Parameters:
  • dataType[in] Data type for computation

  • batchSize[in] Batch size

  • numQHeads[in] Number of query heads

  • numKvHeads[in] Number of key-value heads

  • headSize[in] Head dimension size

  • smVersion[in] CUDA SM version

DecoderXQARunner() = default#
~DecoderXQARunner() = default#
void dispatchXQAKernel(
XQALaunchParams &params,
cudaStream_t const &stream
)#

Dispatch XQA kernel and compute the attention result.

Parameters:
  • params[inout] Launch parameters for XQA kernel

  • stream[in] CUDA stream for kernel execution

void dispatchSpecDecodeXQAKernel(
XQALaunchParams &params,
cudaStream_t const &stream
)#

Dispatch the speculative-decoding (spec-decode) XQA kernel for tree attention.

Parameters:
  • params[inout] Launch parameters for XQA kernel

  • stream[in] CUDA stream for kernel execution

XQALaunchParams initXQAParams()#

Initialize XQA parameters with MHA and hardware configuration.

The returned parameters can be passed to prepareToRun() to query which kernel to dispatch. The caller must set the device pointers before dispatching the XQA kernel.

Returns:

Initialized XQA launch parameters
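
The requirement that the caller set the device pointers before dispatch can be illustrated with a small pre-dispatch check. This is only a sketch and does not use the real headers: `LaunchParamsSketch` is a hypothetical stand-in that mirrors a few documented `XQALaunchParams` members, `missingDevicePointers` is a hypothetical helper, and the choice of which pointers are mandatory is an assumption rather than documented behavior.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in mirroring a few documented XQALaunchParams members.
// The real struct is defined by the library and has more fields.
struct LaunchParamsSketch
{
    void* output = nullptr;          // Output tensor (device memory)
    void const* qInputPtr = nullptr; // Query input (device memory)
    void* treeAttnMask = nullptr;    // Tree attention mask (spec-decode only)
};

// Returns the names of the pointers that are still null.
// Assumption: output and qInputPtr are always required, and treeAttnMask is
// required only when the spec-decode kernel is about to be dispatched.
inline std::vector<std::string> missingDevicePointers(
    LaunchParamsSketch const& params, bool specDecode)
{
    std::vector<std::string> missing;
    if (params.output == nullptr)
    {
        missing.emplace_back("output");
    }
    if (params.qInputPtr == nullptr)
    {
        missing.emplace_back("qInputPtr");
    }
    if (specDecode && params.treeAttnMask == nullptr)
    {
        missing.emplace_back("treeAttnMask");
    }
    return missing;
}
```

A caller could run a check of this kind on the struct returned by initXQAParams() after filling in its pointers, and dispatch only when the returned list is empty.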

Public Static Functions

static bool canImplement(
int32_t numQHeads,
int32_t numKVHeads,
int32_t smVersion,
nvinfer1::DataType dataType
)#

Check if XQA kernel can be implemented with given configuration.

Parameters:
  • numQHeads[in] Number of query heads

  • numKVHeads[in] Number of key-value heads

  • smVersion[in] CUDA SM version

  • dataType[in] Data type for computation

Returns:

True if implementation is supported, false otherwise

static bool loadDecodeXQAKernels(
int32_t smVersion,
nvinfer1::DataType dataType,
bool useSpecDecodeKernels
)#

Load decoder XQA kernels for given configuration.

Parameters:
  • smVersion[in] CUDA SM version

  • dataType[in] Data type for computation

  • useSpecDecodeKernels[in] Whether to load spec-decode kernels

Returns:

True if kernels loaded successfully, false otherwise

struct XQALaunchParams#

Launch parameters for XQA (eXtended Query Attention) kernel.

Public Members

void *output = nullptr#

Output tensor, one of the device memory pointers needed to launch the XQA kernel.

void const *qInputPtr = nullptr#

Query input pointer.

KVCache kvCache#

KV cache structure.

float const *kvScale = nullptr#

KV scaling factors.

int32_t *semaphores = nullptr#

Semaphores for synchronization.

void *scratch = nullptr#

Scratch memory.

void *treeAttnMask = nullptr#

Tree attention mask, a device memory pointer used only for spec-decode tree attention.

int32_t *qCuSeqLen = nullptr#

Cumulative query sequence lengths.
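
As an illustration of what a cumulative sequence-length array typically looks like, the sketch below builds one on the host from per-request query lengths. The helper name `buildCuSeqLens` is hypothetical, and the layout (batchSize + 1 entries with a leading zero) is an assumption borrowed from common attention-kernel conventions, not something this reference states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds cumulative query sequence lengths on the host.
// Assumption: the array has batchSize + 1 entries and starts at 0, so that
// request i owns query tokens [cuSeqLens[i], cuSeqLens[i + 1]).
inline std::vector<int32_t> buildCuSeqLens(std::vector<int32_t> const& qSeqLens)
{
    std::vector<int32_t> cuSeqLens(qSeqLens.size() + 1, 0);
    for (std::size_t i = 0; i < qSeqLens.size(); ++i)
    {
        cuSeqLens[i + 1] = cuSeqLens[i] + qSeqLens[i];
    }
    return cuSeqLens;
}
```

Under that assumption, per-request lengths of 3, 1, and 4 yield 0, 3, 4, 8. Because qCuSeqLen is a device pointer, the host array would still have to be copied to device memory before it is assigned.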

float const *attentionSinks = nullptr#

Attention sinks parameter.

int32_t numQheads = 0#

Number of query heads, one of the MHA parameters used to locate a kernel to launch.

int32_t numKVheads = 0#

Number of key-value heads.

int32_t headSize = 0#

Head dimension size.

int32_t batchSize = 0#

Batch size.

int32_t qSeqLen = 0#

Query sequence length, a parameter for spec-decode tree attention.

float qScale = 1.0F#

Query scaling factor.

int32_t headGroupSize = 0#

Head group size.
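
The reference does not say how the head group size is derived. In grouped-query attention it is commonly the number of query heads that share one key-value head, and the sketch below computes it under that assumption. The helper `computeHeadGroupSize` is hypothetical and not part of the documented API.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Assumption: headGroupSize is the number of query heads that share one
// key-value head, i.e. numQHeads / numKVHeads.
inline int32_t computeHeadGroupSize(int32_t numQHeads, int32_t numKVHeads)
{
    if (numQHeads <= 0 || numKVHeads <= 0 || numQHeads % numKVHeads != 0)
    {
        throw std::invalid_argument(
            "numQHeads must be a positive multiple of numKVHeads");
    }
    return numQHeads / numKVHeads;
}
```

With 32 query heads and 8 key-value heads this gives a group size of 4. Equal head counts give 1, which corresponds to plain multi-head attention.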

nvinfer1::DataType dataType#

I/O data type of the kernel.