MTP State Scatter Kernels

struct MtpLayerInfo

Per-layer pointer bundle for batched MTP state scatter (analogous to KVLayerInfo).

Public Members

void *recurrentDst

Main recurrent state pool [batchSize, recElements] FP32.

void *recurrentSrc

Intermediate recurrent state buffer [batchSize, verifyTreeSize, recElements] FP32.

void *convDst

Main conv state pool [batchSize, convElements] FP16.

void *convSrc

Intermediate conv state buffer [batchSize, verifyTreeSize, convElements] FP16.
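The shapes and element types listed above imply a byte size for each of the four buffers. The following is a minimal host-side sketch of that arithmetic, not part of the library: `MtpLayerInfoSketch` is a local mirror of the documented members, and `MtpBufferBytes` and `mtpBufferBytes` are hypothetical names introduced only for illustration. It assumes densely packed row-major buffers, 4 bytes per FP32 element and 2 bytes per FP16 element.

```cpp
#include <cassert>
#include <cstddef>

// Local mirror of the documented members, for illustration only.
struct MtpLayerInfoSketch {
    void* recurrentDst; // [batchSize, recElements] FP32
    void* recurrentSrc; // [batchSize, verifyTreeSize, recElements] FP32
    void* convDst;      // [batchSize, convElements] FP16
    void* convSrc;      // [batchSize, verifyTreeSize, convElements] FP16
};

// Hypothetical helper type: byte size of each buffer for one layer.
struct MtpBufferBytes {
    std::size_t recurrentDst;
    std::size_t recurrentSrc;
    std::size_t convDst;
    std::size_t convSrc;
};

// Computes the byte sizes implied by the documented shapes,
// assuming dense packing with no padding between rows.
inline MtpBufferBytes mtpBufferBytes(std::size_t batchSize, std::size_t verifyTreeSize,
                                     std::size_t recElements, std::size_t convElements) {
    // This sketch assumes non-empty shapes.
    assert(batchSize > 0 && verifyTreeSize > 0);
    constexpr std::size_t kFp32Bytes = 4;
    constexpr std::size_t kFp16Bytes = 2;
    MtpBufferBytes bytes{};
    bytes.recurrentDst = batchSize * recElements * kFp32Bytes;
    bytes.recurrentSrc = batchSize * verifyTreeSize * recElements * kFp32Bytes;
    bytes.convDst = batchSize * convElements * kFp16Bytes;
    bytes.convSrc = batchSize * verifyTreeSize * convElements * kFp16Bytes;
    return bytes;
}
```

Each intermediate buffer is verifyTreeSize times larger than its matching pool, since it holds one candidate state per verify-tree position.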

void trt_edgellm::kernel::mtpScatterRecurrentStates(
MtpLayerInfo const *deviceLayerInfos,
int32_t numLayers,
int32_t activeBatchSize,
int32_t verifyTreeSize,
int32_t stateElements,
int32_t const *acceptLengths,
cudaStream_t stream
)

Batched scatter of accepted recurrent states (FP32) across all GDN layers. For each (layer, batch) pair: dst[b, :] = src[b, acceptLengths[b] - 1, :]. A batch entry is skipped if acceptLengths[b] <= 0, and the scatter is a no-op if acceptLengths[b] >= verifyTreeSize. stateElements must be divisible by 8 (the DVec<float> vec_size).
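The indexing rule can be checked against a plain CPU loop. The sketch below is a host-side reference, not the library's kernel: `RecurrentLayerRef` and `scatterRecurrentStatesRef` are hypothetical names, the buffers are assumed to be dense row-major host arrays, and both documented exclusions (acceptLengths[b] <= 0 and acceptLengths[b] >= verifyTreeSize) are assumed to leave dst[b, :] unchanged.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side stand-in for one layer's recurrent buffers.
struct RecurrentLayerRef {
    float* dst;       // [batchSize, stateElements]
    float const* src; // [batchSize, verifyTreeSize, stateElements]
};

// CPU reference of the documented rule:
//   dst[b, :] = src[b, acceptLengths[b] - 1, :]
inline void scatterRecurrentStatesRef(std::vector<RecurrentLayerRef> const& layers,
                                      int32_t activeBatchSize, int32_t verifyTreeSize,
                                      int32_t stateElements, int32_t const* acceptLengths) {
    // Documented requirement on stateElements.
    assert(stateElements % 8 == 0);
    for (RecurrentLayerRef const& layer : layers) {
        for (int32_t b = 0; b < activeBatchSize; ++b) {
            int32_t const len = acceptLengths[b];
            // Assumption: both excluded ranges leave dst[b, :] untouched.
            if (len <= 0 || len >= verifyTreeSize) {
                continue;
            }
            // Row-major offsets of src[b, len - 1, 0] and dst[b, 0].
            std::size_t const srcRow =
                (static_cast<std::size_t>(b) * verifyTreeSize + (len - 1)) * stateElements;
            std::size_t const dstRow = static_cast<std::size_t>(b) * stateElements;
            for (int32_t e = 0; e < stateElements; ++e) {
                layer.dst[dstRow + e] = layer.src[srcRow + e];
            }
        }
    }
}
```

A reference like this is mainly useful for comparing against device results copied back to the host on small inputs.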

void trt_edgellm::kernel::mtpScatterConvStates(
MtpLayerInfo const *deviceLayerInfos,
int32_t numLayers,
int32_t activeBatchSize,
int32_t verifyTreeSize,
int32_t stateElements,
int32_t const *acceptLengths,
cudaStream_t stream
)

Batched scatter of accepted conv states (FP16). Same call shape as the recurrent variant.