Decoder XQA Runner

class DecoderXQARunner
Decoder XQA (eXtended Query Attention) kernel runner.
Public Functions
DecoderXQARunner(
    nvinfer1::DataType const dataType,
    int32_t batchSize,
    int32_t numQHeads,
    int32_t numKvHeads,
    int32_t headSize,
    int32_t smVersion)
Constructor for DecoderXQARunner.
Parameters:
dataType – [in] Data type for computation
batchSize – [in] Batch size
numQHeads – [in] Number of query heads
numKvHeads – [in] Number of key-value heads
headSize – [in] Head dimension size
smVersion – [in] CUDA SM version
DecoderXQARunner() = default

~DecoderXQARunner() = default
void dispatchXQAKernel(
    XQALaunchParams &params,
    cudaStream_t const &stream)
Dispatch the XQA kernel to compute the attention result.
Parameters:
params – [inout] Launch parameters for XQA kernel
stream – [in] CUDA stream for kernel execution
void dispatchSpecDecodeXQAKernel(
    XQALaunchParams &params,
    cudaStream_t const &stream)
Dispatch the speculative-decoding (spec-decode) XQA kernel for tree attention.
Parameters:
params – [inout] Launch parameters for XQA kernel
stream – [in] CUDA stream for kernel execution
XQALaunchParams initXQAParams()
Initialize the XQA parameters with the runner's MHA and hardware configuration.
The returned parameters can be passed to prepareToRun() to query which kernel to dispatch. The caller must set up the device pointers before dispatching the XQA kernel.
Returns:
Initialized XQA launch parameters
Public Static Functions
static bool canImplement(
    int32_t numQHeads,
    int32_t numKVHeads,
    int32_t smVersion,
    nvinfer1::DataType dataType)
Check whether the XQA kernel can be implemented for the given configuration.
Parameters:
numQHeads – [in] Number of query heads
numKVHeads – [in] Number of key-value heads
smVersion – [in] CUDA SM version
dataType – [in] Data type for computation
Returns:
True if the configuration is supported, false otherwise.
static bool loadDecodeXQAKernels(
    int32_t smVersion,
    nvinfer1::DataType dataType,
    bool useSpecDecodeKernels)
Load the decoder XQA kernels for the given configuration.
Parameters:
smVersion – [in] CUDA SM version
dataType – [in] Data type for computation
useSpecDecodeKernels – [in] Whether to load spec-decode kernels
Returns:
True if the kernels loaded successfully, false otherwise.
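The functions above suggest a host-side order of calls: check canImplement(), load the kernels with loadDecodeXQAKernels(), construct the runner, obtain parameters from initXQAParams(), set the device pointers, then dispatch. The C++ sketch below illustrates only that order. To stay self-contained it uses hypothetical stand-ins: DataType replaces nvinfer1::DataType, cudaStream_t is aliased to void*, XQALaunchParams is trimmed to a few fields, and every function body is a placeholder (the real runner works with actual kernels and launches CUDA work). The divisibility rule inside the stub canImplement() is an assumption for illustration, not the library's support matrix, and the runDecodeStep helper and its head counts are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins so this sketch compiles without TensorRT or CUDA.
enum class DataType { kHALF, kBF16 };  // stand-in for nvinfer1::DataType
using cudaStream_t = void*;            // stand-in for the CUDA stream handle

// Trimmed stand-in for XQALaunchParams; field names follow this page.
struct XQALaunchParams {
    void* output = nullptr;
    void const* qInputPtr = nullptr;
    int32_t numQheads = 0;
    int32_t numKVheads = 0;
    int32_t headSize = 0;
    int32_t batchSize = 0;
    DataType dataType = DataType::kHALF;
};

// Stub mirroring the documented signatures; every body is a placeholder.
class DecoderXQARunner {
public:
    DecoderXQARunner(DataType const dataType, int32_t batchSize, int32_t numQHeads,
                     int32_t numKvHeads, int32_t headSize, int32_t /*smVersion*/)
        : mDataType(dataType), mBatchSize(batchSize), mNumQHeads(numQHeads),
          mNumKvHeads(numKvHeads), mHeadSize(headSize) {}

    // Assumed rule for illustration: query heads must split evenly across KV heads.
    static bool canImplement(int32_t numQHeads, int32_t numKVHeads, int32_t /*smVersion*/,
                             DataType /*dataType*/) {
        return numKVHeads > 0 && numQHeads % numKVHeads == 0;
    }

    // Placeholder: the real function loads kernels for the SM version and data type.
    static bool loadDecodeXQAKernels(int32_t /*smVersion*/, DataType /*dataType*/,
                                     bool /*useSpecDecodeKernels*/) {
        return true;
    }

    // Copies the MHA configuration; device pointers stay null for the caller to set.
    XQALaunchParams initXQAParams() {
        XQALaunchParams params;
        params.numQheads = mNumQHeads;
        params.numKVheads = mNumKvHeads;
        params.headSize = mHeadSize;
        params.batchSize = mBatchSize;
        params.dataType = mDataType;
        return params;
    }

    // Placeholder: the real function launches the attention kernel on the stream.
    void dispatchXQAKernel(XQALaunchParams& params, cudaStream_t const& /*stream*/) {
        assert(params.output != nullptr && params.qInputPtr != nullptr);
        std::printf("dispatch: batch=%d qHeads=%d kvHeads=%d headSize=%d\n",
                    params.batchSize, params.numQheads, params.numKVheads, params.headSize);
    }

private:
    DataType mDataType;
    int32_t mBatchSize;
    int32_t mNumQHeads;
    int32_t mNumKvHeads;
    int32_t mHeadSize;
};

// Hypothetical caller following the documented order of calls.
bool runDecodeStep(void* dOutput, void const* dQuery, cudaStream_t stream) {
    DataType const dtype = DataType::kHALF;
    int32_t const sm = 90, batch = 4, qHeads = 32, kvHeads = 8, headSize = 128;
    if (!DecoderXQARunner::canImplement(qHeads, kvHeads, sm, dtype)) {
        return false;  // e.g. fall back to another attention path
    }
    if (!DecoderXQARunner::loadDecodeXQAKernels(sm, dtype, /*useSpecDecodeKernels=*/false)) {
        return false;
    }
    DecoderXQARunner runner(dtype, batch, qHeads, kvHeads, headSize, sm);
    XQALaunchParams params = runner.initXQAParams();
    params.output = dOutput;  // device pointers are the caller's responsibility
    params.qInputPtr = dQuery;
    runner.dispatchXQAKernel(params, stream);
    return true;
}
```

With real device buffers the same order would apply; for tree attention, the spec-decode kernels would be loaded and dispatchSpecDecodeXQAKernel() called in place of dispatchXQAKernel().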
struct XQALaunchParams
Launch parameters for the XQA (eXtended Query Attention) kernel.
Public Members
void *output = nullptr
Output tensor. This and the following device memory pointers are used to launch the XQA kernel.
void const *qInputPtr = nullptr
Query input pointer.
KVCache kvCache
KV cache structure.
float const *kvScale = nullptr
KV scaling factors.
int32_t *semaphores = nullptr
Semaphores for synchronization.
void *scratch = nullptr
Scratch memory.
void *treeAttnMask = nullptr
Tree attention mask. This device memory pointer is used only by spec-decode tree attention.
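In tree attention for speculative decoding, each draft token may attend only to itself and to its ancestors in the draft tree. The sketch below builds such a mask from parent indices to show what treeAttnMask conceptually encodes. The helper name buildTreeAttnMask, the dense row-major uint8_t layout, and the input convention (parents[i] == -1 for a root, and every parent listed before its children) are all assumptions; this page does not specify the kernel's actual mask layout, which may differ (for example, bit-packed).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: dense row-major mask where mask[i * n + j] == 1 means
// draft token i may attend to draft token j (j is i itself or one of its ancestors).
// Assumes parents[i] == -1 for a root and parents[i] < i otherwise.
std::vector<uint8_t> buildTreeAttnMask(std::vector<int32_t> const& parents) {
    std::size_t const n = parents.size();
    std::vector<uint8_t> mask(n * n, 0);
    for (std::size_t i = 0; i < n; ++i) {
        int32_t const p = parents[i];
        assert(p >= -1 && p < static_cast<int32_t>(i));
        if (p >= 0) {
            // Token i sees everything its parent sees (the parent's ancestors and the parent).
            std::size_t const parentRow = static_cast<std::size_t>(p) * n;
            for (std::size_t j = 0; j < n; ++j) {
                mask[i * n + j] = mask[parentRow + j];
            }
        }
        mask[i * n + i] = 1;  // every token attends to itself
    }
    return mask;
}
```

Because parents precede children, each row is the parent's row plus the diagonal entry, so the whole mask is filled in a single O(n^2) pass.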
int32_t *qCuSeqLen = nullptr
Cumulative query sequence lengths.
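A common layout for cumulative sequence lengths, used for example by the cu_seqlens_q argument of FlashAttention's variable-length API, is an exclusive prefix sum with batchSize + 1 entries, so that sequence b occupies query rows [qCuSeqLen[b], qCuSeqLen[b + 1]). This page does not state the layout of qCuSeqLen, so treat that convention and the helper name buildQCuSeqLen as assumptions. The sketch builds the array on the host; copying it to device memory is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: exclusive prefix sum of per-sequence query lengths.
// Returns qLens.size() + 1 entries with out[0] == 0 and out[b + 1] - out[b] == qLens[b].
std::vector<int32_t> buildQCuSeqLen(std::vector<int32_t> const& qLens) {
    std::vector<int32_t> cu(qLens.size() + 1, 0);
    for (std::size_t b = 0; b < qLens.size(); ++b) {
        assert(qLens[b] >= 0 && "query lengths must be non-negative");
        cu[b + 1] = cu[b] + qLens[b];
    }
    return cu;
}
```

If every sequence has the same qSeqLen (for example, the same number of draft tokens), entry b reduces to b * qSeqLen under this convention.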
float const *attentionSinks = nullptr
Attention sinks parameter.
int32_t numQheads = 0
Number of query heads. This and the following MHA parameters are used to locate the kernel to launch.
int32_t numKVheads = 0
Number of key-value heads.
int32_t headSize = 0
Head dimension size.
int32_t batchSize = 0
Batch size.
int32_t qSeqLen = 0
Query sequence length. Parameter for spec-decode tree attention.
float qScale = 1.0F
Query scaling factor.
int32_t headGroupSize = 0
Head group size.
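This page does not define headGroupSize further. Under the usual grouped-query attention (GQA) convention, which is an assumption here, it is the number of query heads that share one KV head, numQheads / numKVheads, and query head h reads KV head h / headGroupSize. The helper names headGroupSizeFor and kvHeadFor in this sketch are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Assumed GQA convention: query heads are split evenly across KV heads.
int32_t headGroupSizeFor(int32_t numQheads, int32_t numKVheads) {
    assert(numKVheads > 0 && numQheads % numKVheads == 0);
    return numQheads / numKVheads;
}

// KV head read by query head qHead under that convention.
int32_t kvHeadFor(int32_t qHead, int32_t headGroupSize) {
    assert(headGroupSize > 0 && qHead >= 0);
    return qHead / headGroupSize;
}
```

Multi-head attention is the special case numKVheads == numQheads (group size 1), and multi-query attention the case numKVheads == 1.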
nvinfer1::DataType dataType
I/O data type of the kernel.
void *output = nullptr