Nvfp4 MoE Plugin#

class Nvfp4MoePlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuild, public nvinfer1::IPluginV3OneRuntime#

TensorRT plugin: Nemotron-style MoE MLP with NVFP4 weights.

Per expert: y_e = down_proj( act( up_proj(x) ) ). No separate gate projection.

Dispatch: enqueue() picks between two execution paths based on runtime token count (numTokens = B × S):

  • numTokens ≤ 16 — decode path (router + W4A16 GEMV, per-token GEMV kernels).

  • numTokens > 16 — prefill path (router + GPU layout build + fp4Quantize + gather + FC1 grouped GEMM + fp4Quantize + FC2 grouped GEMM + scatter-reduce).

Routing: both paths consume pre-activation router logits of shape [num_tokens, num_experts] and dispatch to one of two router kernels, selected by the routing_mode attribute:

  • 0 (kSOFTMAX_TOPK, default): moeTopkSoftmax (softmax + flat top-k + renormalize).

  • 1 (kSIGMOID_GROUP_TOPK): moeSigmoidGroupTopk (sigmoid + grouped top-k + renormalize + scale). Uses the n_group / topk_group / norm_topk_prob / routed_scaling_factor attributes. In both modes, e_score_correction_bias ([E], FP32) is consumed by the router: moeTopkSoftmax (mode 0) treats it as an optional bias, and moeSigmoidGroupTopk (mode 1) uses it as the expert load-balancing bias; pass zeros when no bias is desired.

Hidden activations are FP16 only. Any NVFP4 quantization of activations (payload + FP8 block scales) needed by the prefill path is computed inside the plugin via fp4Quantize; the caller supplies only calibrated global scales.

Public Functions

Nvfp4MoePlugin(
std::string const &name,
int32_t numExperts,
int32_t topK,
int32_t hiddenSize,
int32_t moeInterSize,
nvinfer1::ActivationType activationType = static_cast<nvinfer1::ActivationType>(0),
int32_t nGroup = 1,
int32_t topkGroup = 1,
int32_t normTopkProb = 1,
float routedScalingFactor = 1.0f,
int32_t routingMode = static_cast<int32_t>(Nvfp4MoeRoutingMode::kSOFTMAX_TOPK)
)#

Build-phase constructor — populated from ONNX attributes. mMaxTokens starts at 0 and is set in configurePlugin().

Nvfp4MoePlugin(
std::string const &name,
nvinfer1::PluginFieldCollection const *fc
)#

Runtime deserialization constructor. Reads max_tokens from the serialized field collection and throws if the field is missing.

Nvfp4MoePlugin() = delete#
Nvfp4MoePlugin(Nvfp4MoePlugin const&) = delete#
~Nvfp4MoePlugin() noexcept override#
nvinfer1::IPluginCapability *getCapabilityInterface(
nvinfer1::PluginCapabilityType type
) noexcept override#
nvinfer1::IPluginV3 *clone() noexcept override#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
char const *getPluginNamespace() const noexcept override#
int32_t getNbOutputs() const noexcept override#
int32_t getOutputDataTypes(
nvinfer1::DataType *outputTypes,
int32_t nbOutputs,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#
int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::DimsExprs const *shapeInputs,
int32_t nbShapeInputs,
nvinfer1::DimsExprs *outputs,
int32_t nbOutputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#
bool supportsFormatCombination(
int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#
int32_t configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
size_t getWorkspaceSize(
nvinfer1::DynamicPluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#
int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#
int32_t onShapeChange(
nvinfer1::PluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
nvinfer1::IPluginV3 *attachToContext(
nvinfer1::IPluginResourceContext *context
) noexcept override#
nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
class Nvfp4MoePluginCreator : public nvinfer1::IPluginCreatorV3One#

Public Functions

Nvfp4MoePluginCreator()#
~Nvfp4MoePluginCreator() override = default#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#
char const *getPluginNamespace() const noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
nvinfer1::IPluginV3 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase
) noexcept override#