Moe Marlin

void trt_edgellm::kernel::moeAwqW4A16MarlinGemm(
rt::Tensor const &input,
rt::Tensor &output,
rt::Tensor const &weights,
rt::Tensor const &scales,
rt::Tensor const &sortedTokenIds,
rt::Tensor const &expertIds,
rt::Tensor const &numTokensPostPadded,
rt::Tensor const &topkWeights,
rt::Tensor &workspace,
int64_t moeBlockSize,
int64_t topK,
bool mulTopkWeights,
cudaStream_t stream
)

MoE W4A16 GEMM using Marlin kernel with AWQ format.

Performs MoE grouped GEMM with 4-bit AWQ quantized weights and 16-bit activations. Uses kU4 format where zero point = 8 is baked into dequantization.

Dequantization: weight_fp16 = (weight_int4 - 8) * scale
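The dequantization rule above can be sketched on the host as follows. This is a minimal illustration, not the kernel's implementation: `dequantizeU4` and `extractNibble` are hypothetical helpers, `float` stands in for FP16, and the plain low-to-high nibble order is an assumption (the Marlin-repacked layout orders values differently).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not part of trt_edgellm: dequantizes one unsigned
// 4-bit value with the fixed zero point of 8 described above.
// float stands in for FP16 so the sketch runs on the host.
inline float dequantizeU4(std::uint8_t q, float scale)
{
    assert(q <= 15 && "a 4-bit value must lie in [0, 15]");
    return static_cast<float>(static_cast<int>(q) - 8) * scale;
}

// Hypothetical helper: extracts nibble `idx` (0 = lowest 4 bits) from a
// 32-bit word holding eight 4-bit values. This plain ordering is an
// assumption for illustration only.
inline std::uint8_t extractNibble(std::uint32_t packed, int idx)
{
    assert(idx >= 0 && idx < 8);
    return static_cast<std::uint8_t>((packed >> (4 * idx)) & 0xFu);
}
```

With a scale of 0.5, the stored value 8 maps to 0 and the stored value 0 maps to -4, so the representable range is asymmetric: [-8, 7] times the scale.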

Note: Weights must be pre-swizzled into Marlin format using awq_marlin_repack.

Internally uses FP32 reduction for numerical accuracy. The workspace buffer must be sized using getMoeMarlinWorkspaceSize(), which accounts for both the synchronization locks and the FP32 reduction buffer.

Parameters:
  • input – Input activations [numTokens, hiddenDim] (FP16)

  • output – Output tensor [numTokens * topK, outDim] (FP16)

  • weights – Marlin-repacked INT4 weights [numExperts, K/tile, N*tile/pack], where K = hiddenDim and N = outDim

  • scales – Per-group scales [numExperts, numGroups, outDim] (FP16)

  • sortedTokenIds – Sorted token indices [numTokensPadded] (INT32)

  • expertIds – Expert assignment per block [numBlocks] (INT32)

  • numTokensPostPadded – Total padded token count [1] (INT32)

  • topkWeights – Routing weights [numTokensPadded] (FP32)

  • workspace – Workspace buffer sized by getMoeMarlinWorkspaceSize() (INT32)

  • moeBlockSize – MoE processing block size (8, 16, 32, 48, or 64)

  • topK – Number of experts per token

  • mulTopkWeights – Whether to multiply output by topk weights

  • stream – CUDA stream
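To make the relationship between sortedTokenIds, expertIds, and numTokensPostPadded concrete, here is a host-side sketch of one plausible way to build them. It assumes the layout commonly used by block-aligned MoE kernels: token slots are grouped by expert, and each group is padded to a multiple of moeBlockSize. The struct, the function name, and the choice of padding sentinel (the total slot count) are all assumptions, not part of this API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference, not part of trt_edgellm.
struct MoeAlignment
{
    std::vector<std::int32_t> sortedTokenIds; // slots grouped by expert, padded
    std::vector<std::int32_t> expertIds;      // one expert index per block
    std::int32_t numTokensPostPadded{0};      // total length including padding
};

// slotToExpert[i] is the expert chosen for flattened slot i
// (numTokens * topK slots in total).
inline MoeAlignment alignTokensToBlocks(
    std::vector<std::int32_t> const& slotToExpert, std::int32_t numExperts, std::int32_t moeBlockSize)
{
    assert(numExperts > 0 && moeBlockSize > 0);
    auto const numSlots = static_cast<std::int32_t>(slotToExpert.size());

    // Bucket slot indices by expert, preserving their original order.
    std::vector<std::vector<std::int32_t>> buckets(static_cast<std::size_t>(numExperts));
    for (std::int32_t slot = 0; slot < numSlots; ++slot)
    {
        std::int32_t const expert = slotToExpert[static_cast<std::size_t>(slot)];
        assert(expert >= 0 && expert < numExperts);
        buckets[static_cast<std::size_t>(expert)].push_back(slot);
    }

    MoeAlignment result;
    for (std::int32_t expert = 0; expert < numExperts; ++expert)
    {
        auto const& bucket = buckets[static_cast<std::size_t>(expert)];
        auto const count = static_cast<std::int32_t>(bucket.size());
        std::int32_t const numBlocks = (count + moeBlockSize - 1) / moeBlockSize;

        for (std::int32_t slot : bucket)
        {
            result.sortedTokenIds.push_back(slot);
        }
        // Pad the group to a whole number of blocks. The sentinel numSlots
        // (an out-of-range slot index) is an assumption of this sketch.
        for (std::int32_t pad = count; pad < numBlocks * moeBlockSize; ++pad)
        {
            result.sortedTokenIds.push_back(numSlots);
        }
        for (std::int32_t block = 0; block < numBlocks; ++block)
        {
            result.expertIds.push_back(expert);
        }
    }
    result.numTokensPostPadded = static_cast<std::int32_t>(result.sortedTokenIds.size());
    return result;
}
```

In this sketch an expert that receives no tokens contributes no blocks, so expertIds has exactly numTokensPostPadded / moeBlockSize entries.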

int64_t trt_edgellm::kernel::getMoeMarlinWorkspaceSize(
int64_t numTokensPadded,
int64_t outDim,
int64_t moeBlockSize,
int64_t numSMs
)

Get required workspace size for MoE Marlin GEMM.

Returns the total workspace size needed, which includes:

  • Synchronization locks for thread blocks

  • FP32 reduction buffer (c_tmp) for numerical accuracy

Parameters:
  • numTokensPadded – Maximum padded token count

  • outDim – Output dimension N

  • moeBlockSize – MoE block size

  • numSMs – Number of SMs on the device

Returns:

Required workspace size in number of int32_t elements
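Because the return value counts int32_t elements rather than bytes, a caller allocating raw device memory needs to convert it. The sketch below shows that conversion along with a worst-case bound for numTokensPadded. Both helpers are hypothetical, and the bound assumes each expert's token group is padded independently to a multiple of moeBlockSize, which this page does not state explicitly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: converts the element count returned by
// getMoeMarlinWorkspaceSize() into a byte count for a raw allocation.
inline std::size_t workspaceBytes(std::int64_t numInt32Elements)
{
    assert(numInt32Elements >= 0);
    return static_cast<std::size_t>(numInt32Elements) * sizeof(std::int32_t);
}

// Hypothetical helper: upper bound on the padded token count, assuming each
// expert's group is padded separately, so each expert wastes at most
// (moeBlockSize - 1) padding slots.
inline std::int64_t maxPaddedTokens(
    std::int64_t numTokens, std::int64_t topK, std::int64_t numExperts, std::int64_t moeBlockSize)
{
    assert(numTokens >= 0 && topK > 0 && numExperts > 0 && moeBlockSize > 0);
    return numTokens * topK + numExperts * (moeBlockSize - 1);
}
```

Sizing the workspace once from this upper bound avoids reallocating whenever the routing changes between calls, at the cost of some unused memory.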