NVFP4 MoE Plugin GeForce#

class NvFP4MoEPluginGeforce : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuild, public nvinfer1::IPluginV3OneRuntime#

TensorRT plugin for NVFP4 fused MoE (CuTeDSL, SM120/SM121): FP16 activations, dynamic on-the-fly NVFP4 quantization, and a fused route/pack + FC1 + activation + quantization + FC2 + scatter pipeline.

Per expert: y_e = down_proj( act( up_proj(x) ) ) with NVFP4 packed weights.

Weight layout: FC1 is the plain [up_all, gate_all] concatenation along the M axis (no 64-row up/gate interleaving). This matches what the fused SM12x CuTeDSL kernel expects natively. By contrast, the SM110 Nvfp4MoePlugin uses the separate 64-row interleaved layout consumed by the split FC1/FC2 backend.

Note

This plugin is only supported on SM120 and SM121 (consumer Blackwell).

Note

This plugin supports only FP16 inputs and outputs.

Note

Supported activations: identity, silu, swiglu, gelu, relu2.

Public Functions

NvFP4MoEPluginGeforce(
std::string const &name,
int32_t numExperts,
int32_t topK,
int32_t hiddenSize,
int32_t moeInterSize,
int32_t activationType,
int32_t nGroup,
int32_t topkGroup,
int32_t normTopkProb,
float routedScalingFactor,
int32_t routingMode,
int32_t backend,
int32_t maxRoutedRows,
int32_t ioDtype
)#
NvFP4MoEPluginGeforce(
std::string const &name,
nvinfer1::PluginFieldCollection const *fc
)#
NvFP4MoEPluginGeforce() = delete#
NvFP4MoEPluginGeforce(NvFP4MoEPluginGeforce const&) = delete#
~NvFP4MoEPluginGeforce() noexcept override#
nvinfer1::IPluginCapability *getCapabilityInterface(
nvinfer1::PluginCapabilityType type
) noexcept override#
nvinfer1::IPluginV3 *clone() noexcept override#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
char const *getPluginNamespace() const noexcept override#
int32_t getNbOutputs() const noexcept override#
int32_t getOutputDataTypes(
nvinfer1::DataType *outputTypes,
int32_t nbOutputs,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#
int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::DimsExprs const *shapeInputs,
int32_t nbShapeInputs,
nvinfer1::DimsExprs *outputs,
int32_t nbOutputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#
bool supportsFormatCombination(
int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#
int32_t configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
size_t getWorkspaceSize(
nvinfer1::DynamicPluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#
int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#
int32_t onShapeChange(
nvinfer1::PluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
nvinfer1::IPluginV3 *attachToContext(
nvinfer1::IPluginResourceContext *context
) noexcept override#
nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
class NvFP4MoEPluginGeforceCreator : public nvinfer1::IPluginCreatorV3One#

Plugin creator: parses the PluginFieldCollection into the attributes above, registers under TensorRT’s default namespace, and exposes the name “NvFP4MoEPluginGeforce” with version “1”.

Public Functions

NvFP4MoEPluginGeforceCreator()#
~NvFP4MoEPluginGeforceCreator() override = default#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#
char const *getPluginNamespace() const noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
nvinfer1::IPluginV3 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase
) noexcept override#