NVFP4 MoE Plugin
-
class Nvfp4MoePlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuild, public nvinfer1::IPluginV3OneRuntime
TensorRT plugin: Nemotron-style MoE MLP with NVFP4 weights.
Per expert:
$y_e = \mathrm{down\_proj}\big(\mathrm{act}(\mathrm{up\_proj}(x))\big)$. There is no separate gate projection.
Dispatch: `enqueue()` picks between two execution paths based on the runtime token count ($\mathrm{numTokens} = B \times S$):
- `numTokens` ≤ 16 — decode path (router + W4A16 GEMV, per-token GEMV kernels).
- `numTokens` > 16 — prefill path (router + GPU layout build + `fp4Quantize` + gather + FC1 grouped GEMM + `fp4Quantize` + FC2 grouped GEMM + scatter-reduce).
Routing: both paths consume pre-activation router logits of shape
`[num_tokens, num_experts]` and dispatch to one of two router kernels, selected by the `routing_mode` attribute:
- `0` (`kSOFTMAX_TOPK`, default): `moeTopkSoftmax` (softmax + flat top-k + renormalize).
- `1` (`kSIGMOID_GROUP_TOPK`): `moeSigmoidGroupTopk` (sigmoid + grouped top-k + renormalize + scale). Uses the `n_group`, `topk_group`, `norm_topk_prob`, and `routed_scaling_factor` attributes.

`e_score_correction_bias` (`[E]`, FP32) is used as an optional bias by `moeTopkSoftmax` (mode 0) and as the expert load-balancing bias by `moeSigmoidGroupTopk` (mode 1); pass zeros when no bias is desired.
Hidden activations are FP16 only. Any NVFP4 activation quantization (payload + FP8 block scales) needed by the prefill path is computed inside the plugin via
`fp4Quantize`; the caller supplies only the calibrated global scales.

Public Functions
- Nvfp4MoePlugin(
- std::string const &name,
- int32_t numExperts,
- int32_t topK,
- int32_t hiddenSize,
- int32_t moeInterSize,
- nvinfer1::ActivationType activationType = static_cast<nvinfer1::ActivationType>(0),
- int32_t nGroup = 1,
- int32_t topkGroup = 1,
- int32_t normTopkProb = 1,
- float routedScalingFactor = 1.0f,
- int32_t routingMode = static_cast<int32_t>(Nvfp4MoeRoutingMode::kSOFTMAX_TOPK)
Build-phase constructor — populated from ONNX attributes.
`mMaxTokens` starts at 0 and is set in `configurePlugin()`.
- Nvfp4MoePlugin(
- std::string const &name,
- nvinfer1::PluginFieldCollection const *fc
Runtime deserialization constructor. Reads
`max_tokens` from the serialized field collection; a missing field throws.
-
Nvfp4MoePlugin() = delete
-
Nvfp4MoePlugin(Nvfp4MoePlugin const&) = delete
-
~Nvfp4MoePlugin() noexcept override
- nvinfer1::IPluginCapability *getCapabilityInterface(
- nvinfer1::PluginCapabilityType type
-
nvinfer1::IPluginV3 *clone() noexcept override
-
char const *getPluginName() const noexcept override
-
char const *getPluginVersion() const noexcept override
-
char const *getPluginNamespace() const noexcept override
-
int32_t getNbOutputs() const noexcept override
- int32_t getOutputDataTypes(
- nvinfer1::DataType *outputTypes,
- int32_t nbOutputs,
- nvinfer1::DataType const *inputTypes,
- int32_t nbInputs
- int32_t getOutputShapes(
- nvinfer1::DimsExprs const *inputs,
- int32_t nbInputs,
- nvinfer1::DimsExprs const *shapeInputs,
- int32_t nbShapeInputs,
- nvinfer1::DimsExprs *outputs,
- int32_t nbOutputs,
- nvinfer1::IExprBuilder &exprBuilder
- bool supportsFormatCombination(
- int32_t pos,
- nvinfer1::DynamicPluginTensorDesc const *inOut,
- int32_t nbInputs,
- int32_t nbOutputs
- int32_t configurePlugin(
- nvinfer1::DynamicPluginTensorDesc const *in,
- int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const *out,
- int32_t nbOutputs
- size_t getWorkspaceSize(
- nvinfer1::DynamicPluginTensorDesc const *inputs,
- int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const *outputs,
- int32_t nbOutputs
- int32_t enqueue(
- nvinfer1::PluginTensorDesc const *inputDesc,
- nvinfer1::PluginTensorDesc const *outputDesc,
- void const *const *inputs,
- void *const *outputs,
- void *workspace,
- cudaStream_t stream
- int32_t onShapeChange(
- nvinfer1::PluginTensorDesc const *in,
- int32_t nbInputs,
- nvinfer1::PluginTensorDesc const *out,
- int32_t nbOutputs
- nvinfer1::IPluginV3 *attachToContext(
- nvinfer1::IPluginResourceContext *context
- nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
-
void setPluginNamespace(char const *pluginNamespace) noexcept
-
class Nvfp4MoePluginCreator : public nvinfer1::IPluginCreatorV3One
Public Functions
-
Nvfp4MoePluginCreator()
-
~Nvfp4MoePluginCreator() override = default
-
char const *getPluginName() const noexcept override
-
char const *getPluginVersion() const noexcept override
- nvinfer1::PluginFieldCollection const *getFieldNames(
-
char const *getPluginNamespace() const noexcept override
-
void setPluginNamespace(char const *pluginNamespace) noexcept
- nvinfer1::IPluginV3 *createPlugin(
- char const *name,
- nvinfer1::PluginFieldCollection const *fc,
- nvinfer1::TensorRTPhase phase
-
Nvfp4MoePluginCreator()#