Mtp State Scatter Kernels
-
struct MtpLayerInfo
Per-layer pointer bundle for batched MTP (multi-token prediction) state scatter, analogous to KVLayerInfo.
Public Members
-
void *recurrentDst
Main recurrent state pool [batchSize, recElements] FP32.
-
void *recurrentSrc
Intermediate recurrent state buffer [batchSize, verifyTreeSize, recElements] FP32.
-
void *convDst
Main conv state pool [batchSize, convElements] FP16.
-
void *convSrc
Intermediate conv state buffer [batchSize, verifyTreeSize, convElements] FP16.
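The shapes listed for these buffers determine where each row lives. The sketch below shows the flat element offsets, assuming dense row-major (C-contiguous) storage with no padding between rows. The helpers poolOffset and intermediateOffset are hypothetical names for illustration and are not part of the library.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helpers (not part of the library): flat element offsets into the
// MtpLayerInfo buffers, assuming dense row-major layouts with no row padding.

// Main pool [batchSize, elements]: row b starts at b * elements.
constexpr std::size_t poolOffset(int32_t b, int32_t elements)
{
    return static_cast<std::size_t>(b) * static_cast<std::size_t>(elements);
}

// Intermediate buffer [batchSize, verifyTreeSize, elements]:
// entry (b, t) starts at (b * verifyTreeSize + t) * elements.
constexpr std::size_t intermediateOffset(int32_t b, int32_t t, int32_t verifyTreeSize, int32_t elements)
{
    return (static_cast<std::size_t>(b) * static_cast<std::size_t>(verifyTreeSize) + static_cast<std::size_t>(t))
        * static_cast<std::size_t>(elements);
}
```

For example, with verifyTreeSize = 4 and recElements = 8, batch 1 at tree position t = 2 starts at element (1 * 4 + 2) * 8 = 48 of recurrentSrc.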
-
void trt_edgellm::kernel::mtpScatterRecurrentStates(
    MtpLayerInfo const *deviceLayerInfos,
    int32_t numLayers,
    int32_t activeBatchSize,
    int32_t verifyTreeSize,
    int32_t stateElements,
    int32_t const *acceptLengths,
    cudaStream_t stream)
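Batched scatter of accepted recurrent states (FP32) across all GDN layers. deviceLayerInfos points to numLayers MtpLayerInfo entries in device memory. For each (layer, batch) pair, recurrentDst[b, :] = recurrentSrc[b, acceptLengths[b] - 1, :], which copies the intermediate state at the last accepted position into the main pool. Batch entries with acceptLengths[b] <= 0 are skipped; the copy is a no-op for entries with acceptLengths[b] >= verifyTreeSize. stateElements must be divisible by 8 (the vec_size of DVec<float>).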
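The indexing rule can be checked against a plain CPU reference, which is also handy as a test oracle for the kernel. This is a minimal sketch, not the library's implementation: HostLayerInfo and referenceScatterRecurrent are hypothetical names, the buffers are assumed dense and row-major, and both documented guards are interpreted as "leave the destination row untouched". The copy is grouped into 8-element chunks only to illustrate why stateElements must be a multiple of the documented vector width; the real kernel's thread mapping is not described in this reference.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side mirror of MtpLayerInfo (only the recurrent fields are used here).
struct HostLayerInfo
{
    float* recurrentDst;       // [batchSize, stateElements]
    float const* recurrentSrc; // [batchSize, verifyTreeSize, stateElements]
};

// CPU reference for the documented semantics:
// for each (layer, b): dst[b, :] = src[b, acceptLengths[b] - 1, :].
void referenceScatterRecurrent(HostLayerInfo const* layers, int32_t numLayers, int32_t activeBatchSize,
    int32_t verifyTreeSize, int32_t stateElements, int32_t const* acceptLengths)
{
    constexpr int32_t kVecSize = 8; // documented DVec<float> vec_size
    assert(stateElements % kVecSize == 0);
    std::size_t const rowElems = static_cast<std::size_t>(stateElements);
    for (int32_t l = 0; l < numLayers; ++l)
    {
        for (int32_t b = 0; b < activeBatchSize; ++b)
        {
            int32_t const accept = acceptLengths[b];
            // Documented guards: skip when accept <= 0, no-op when accept >= verifyTreeSize.
            if (accept <= 0 || accept >= verifyTreeSize)
            {
                continue;
            }
            std::size_t const srcRow = static_cast<std::size_t>(b) * static_cast<std::size_t>(verifyTreeSize)
                + static_cast<std::size_t>(accept - 1);
            float const* src = layers[l].recurrentSrc + srcRow * rowElems;
            float* dst = layers[l].recurrentDst + static_cast<std::size_t>(b) * rowElems;
            // Copy in 8-element chunks to illustrate the vector-width requirement.
            for (int32_t v = 0; v < stateElements; v += kVecSize)
            {
                for (int32_t i = 0; i < kVecSize; ++i)
                {
                    dst[v + i] = src[v + i];
                }
            }
        }
    }
}
```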
-
void trt_edgellm::kernel::mtpScatterConvStates(
    MtpLayerInfo const *deviceLayerInfos,
    int32_t numLayers,
    int32_t activeBatchSize,
    int32_t verifyTreeSize,
    int32_t stateElements,
    int32_t const *acceptLengths,
    cudaStream_t stream)
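Batched scatter of accepted conv states (FP16), reading convSrc and writing convDst in each MtpLayerInfo. Takes the same parameters and uses the same indexing as mtpScatterRecurrentStates.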
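Because the scatter is a pure copy, the conv variant can be modeled with the same indexing over 16-bit payloads. The sketch below is a hypothetical CPU reference, not the library's code: HostConvLayer, referenceScatter, and referenceScatterConv are illustrative names, FP16 values are held as raw uint16_t bit patterns, and the acceptLengths guards are assumed to match the recurrent variant.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical host-side pointers for the conv path. FP16 values are held as raw
// uint16_t bit patterns: the scatter only copies bytes, so no half-precision math is needed.
struct HostConvLayer
{
    uint16_t* convDst;       // [batchSize, stateElements]
    uint16_t const* convSrc; // [batchSize, verifyTreeSize, stateElements]
};

// Element-type-agnostic reference: dst[b, :] = src[b, acceptLengths[b] - 1, :].
template <typename T>
void referenceScatter(T* dst, T const* src, int32_t activeBatchSize, int32_t verifyTreeSize,
    int32_t stateElements, int32_t const* acceptLengths)
{
    assert(verifyTreeSize > 0 && stateElements >= 0);
    std::size_t const rowElems = static_cast<std::size_t>(stateElements);
    for (int32_t b = 0; b < activeBatchSize; ++b)
    {
        int32_t const accept = acceptLengths[b];
        // Assumed to match the recurrent variant's guards.
        if (accept <= 0 || accept >= verifyTreeSize)
        {
            continue;
        }
        std::size_t const srcRow = static_cast<std::size_t>(b) * static_cast<std::size_t>(verifyTreeSize)
            + static_cast<std::size_t>(accept - 1);
        std::memcpy(dst + static_cast<std::size_t>(b) * rowElems, src + srcRow * rowElems, rowElems * sizeof(T));
    }
}

// Conv variant: the same copy applied to every layer's convSrc/convDst pair.
void referenceScatterConv(HostConvLayer const* layers, int32_t numLayers, int32_t activeBatchSize,
    int32_t verifyTreeSize, int32_t stateElements, int32_t const* acceptLengths)
{
    for (int32_t l = 0; l < numLayers; ++l)
    {
        referenceScatter(layers[l].convDst, layers[l].convSrc, activeBatchSize, verifyTreeSize, stateElements,
            acceptLengths);
    }
}
```

Holding FP16 as uint16_t keeps the sketch in standard C++; in CUDA code the same template could be instantiated with __half from cuda_fp16.h.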