Moe Marlin
- void trt_edgellm::kernel::moeAwqW4A16MarlinGemm(
  rt::Tensor const &input,
  rt::Tensor &output,
  rt::Tensor const &weights,
  rt::Tensor const &scales,
  rt::Tensor const &sortedTokenIds,
  rt::Tensor const &expertIds,
  rt::Tensor const &numTokensPostPadded,
  rt::Tensor const &topkWeights,
  rt::Tensor &workspace,
  int64_t moeBlockSize,
  int64_t topK,
  bool mulTopkWeights,
  cudaStream_t stream)
MoE W4A16 GEMM using the Marlin kernel with AWQ-format weights.
Performs MoE grouped GEMM with 4-bit AWQ quantized weights and 16-bit activations. Uses the kU4 format, in which the zero point of 8 is baked into dequantization.
Dequantization: weight_fp16 = (weight_int4 - 8) * scale
Note: Weights must be pre-swizzled into Marlin format using awq_marlin_repack.
Internally uses FP32 reduction for numerical accuracy. The workspace buffer must be sized with getMoeMarlinWorkspaceSize(), which accounts for both the synchronization locks and the FP32 reduction buffer.
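The dequantization rule above can be illustrated on the CPU. This is a minimal sketch, not the kernel's implementation: it uses float in place of FP16, and the helper names dequantU4 and scaleOffset, the groupSize parameter, and the contiguous row-major layout of the [numExperts, numGroups, outDim] scales tensor are all assumptions made for illustration. It does not model the Marlin packing order.

```cpp
#include <cstdint>

// Dequantize one unsigned 4-bit weight (0..15) using the fixed zero point 8:
//   weight = (weight_int4 - 8) * scale
// float stands in for FP16 here (illustrative assumption).
inline float dequantU4(std::uint8_t q, float scale)
{
    return static_cast<float>(static_cast<int>(q & 0xF) - 8) * scale;
}

// Hypothetical helper: linear offset of the scale that applies to input
// row k and output column n of a given expert, assuming a contiguous
// row-major scales tensor of shape [numExperts, numGroups, outDim] and
// groups of groupSize consecutive rows along K.
inline std::int64_t scaleOffset(std::int64_t expert, std::int64_t k, std::int64_t n,
                                std::int64_t numGroups, std::int64_t outDim,
                                std::int64_t groupSize)
{
    std::int64_t const group = k / groupSize;
    return (expert * numGroups + group) * outDim + n;
}
```

Because the zero point is fixed at 8, a stored value of 8 always dequantizes to zero regardless of scale, and the representable range is asymmetric: -8 * scale to 7 * scale.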
- Parameters:
input – Input activations [numTokens, hiddenDim] (FP16)
output – Output tensor [numTokens * topK, outDim] (FP16)
weights – Marlin-repacked INT4 weights [numExperts, K/tile, N*tile/pack]
scales – Per-group scales [numExperts, numGroups, outDim] (FP16)
sortedTokenIds – Sorted token indices [numTokensPadded] (INT32)
expertIds – Expert assignment per block [numBlocks] (INT32)
numTokensPostPadded – Total padded token count [1] (INT32)
topkWeights – Routing weights [numTokensPadded] (FP32)
workspace – Workspace buffer sized by getMoeMarlinWorkspaceSize() (INT32)
moeBlockSize – MoE processing block size (8, 16, 32, 48, or 64)
topK – Number of experts per token
mulTopkWeights – Whether to multiply output by topk weights
stream – CUDA stream
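The routing inputs (sortedTokenIds, expertIds, numTokensPostPadded) are not specified in detail here. The sketch below shows one plausible way to build them on the CPU, assuming the block-aligned layout commonly used with fused MoE kernels (for example, vLLM's moe_align_block_size): each expert's slots are grouped together and padded to a multiple of moeBlockSize. The helper alignToBlocks, the MoeAlignment struct, the sentinel value used for padding, and the choice to skip experts with no tokens are all assumptions, not confirmed behavior of this library.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the three routing tensors.
struct MoeAlignment
{
    std::vector<std::int32_t> sortedTokenIds; // [numTokensPadded]
    std::vector<std::int32_t> expertIds;      // [numBlocks]
    std::int32_t numTokensPostPadded{0};      // [1]
};

// topkIds: flattened [numTokens * topK] expert choice per (token, slot) pair.
// Assumption: padding entries hold the out-of-range sentinel topkIds.size().
inline MoeAlignment alignToBlocks(std::vector<std::int32_t> const& topkIds,
                                  std::int32_t numExperts, std::int32_t moeBlockSize)
{
    auto const numSlots = static_cast<std::int32_t>(topkIds.size());

    // Bucket the flattened slot indices by their assigned expert.
    std::vector<std::vector<std::int32_t>> buckets(static_cast<std::size_t>(numExperts));
    for (std::int32_t i = 0; i < numSlots; ++i)
    {
        buckets[static_cast<std::size_t>(topkIds[static_cast<std::size_t>(i)])].push_back(i);
    }

    MoeAlignment out;
    for (std::int32_t e = 0; e < numExperts; ++e)
    {
        auto const& bucket = buckets[static_cast<std::size_t>(e)];
        if (bucket.empty())
        {
            continue; // Assumption: experts with no tokens get no blocks.
        }
        std::size_t const blockSize = static_cast<std::size_t>(moeBlockSize);
        std::size_t const numBlocks = (bucket.size() + blockSize - 1) / blockSize;
        std::size_t const padding = numBlocks * blockSize - bucket.size();

        // Real slots first, then sentinel padding up to the block boundary.
        out.sortedTokenIds.insert(out.sortedTokenIds.end(), bucket.begin(), bucket.end());
        out.sortedTokenIds.insert(out.sortedTokenIds.end(), padding, numSlots);
        // Every block of this expert records the same expert id.
        out.expertIds.insert(out.expertIds.end(), numBlocks, e);
    }
    out.numTokensPostPadded = static_cast<std::int32_t>(out.sortedTokenIds.size());
    return out;
}
```

With this layout, each block of moeBlockSize entries in sortedTokenIds belongs to exactly one expert, which is what lets expertIds be indexed per block rather than per token.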
- int64_t trt_edgellm::kernel::getMoeMarlinWorkspaceSize(
  int64_t numTokensPadded,
  int64_t outDim,
  int64_t moeBlockSize,
  int64_t numSMs)
Get required workspace size for MoE Marlin GEMM.
Returns the total workspace size needed, which includes the synchronization locks for thread blocks and the FP32 reduction buffer (c_tmp) used for numerical accuracy.
- Parameters:
numTokensPadded – Maximum padded token count
outDim – Output dimension N
moeBlockSize – MoE block size
numSMs – Number of SMs on the device
- Returns:
Required workspace size in number of int32_t elements
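Since the return value is a count of int32_t elements rather than bytes, a caller has to convert it before allocating. The sketch below shows that conversion together with a worst-case bound for numTokensPadded. The helpers maxPaddedTokens and workspaceBytes are hypothetical, and the bound assumes each expert's token list is padded independently to a multiple of moeBlockSize, which this page does not state explicitly.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper: upper bound on the padded token count, assuming each
// expert's list is rounded up to a multiple of moeBlockSize. Every expert can
// then add at most (moeBlockSize - 1) padding entries.
inline std::int64_t maxPaddedTokens(std::int64_t numTokens, std::int64_t topK,
                                    std::int64_t numExperts, std::int64_t moeBlockSize)
{
    return numTokens * topK + numExperts * (moeBlockSize - 1);
}

// Hypothetical helper: convert the int32_t element count returned by
// getMoeMarlinWorkspaceSize() into a byte count for allocation.
inline std::size_t workspaceBytes(std::int64_t numInt32Elements)
{
    return static_cast<std::size_t>(numInt32Elements) * sizeof(std::int32_t);
}

// Intended use (not compiled here, since it needs the library and CUDA headers):
//   int numSMs = 0;
//   cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, /*device=*/0);
//   int64_t padded = maxPaddedTokens(numTokens, topK, numExperts, moeBlockSize);
//   int64_t elems  = trt_edgellm::kernel::getMoeMarlinWorkspaceSize(
//       padded, outDim, moeBlockSize, numSMs);
//   std::size_t bytes = workspaceBytes(elems);
```

Sizing against the worst-case padded count lets the workspace be allocated once and reused across calls with different routing outcomes.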