NvFP4 MoE Contiguous GEMM Runner
class NvFP4MoEContiguousGemmRunner
Public Functions
NvFP4MoEContiguousGemmRunner(
    int32_t numLocalExperts,
    int32_t topK,
    int32_t n,
    int32_t k,
    int32_t tileSize = 128,
    Activation activation = Activation::kRelu2,
    OutputDType outDtype = OutputDType::kBF16
)
- Parameters:
numLocalExperts – Number of local experts (L)
topK – Routing factor
n – Intermediate size (N)
k – Hidden size (K)
tileSize – Tile size (default 128)
activation – Activation function (compiled into the AOT binary). Only Relu2 and Swiglu are compiled; Identity is not exported as it has no production use for FC1.
outDtype – Output element type (selects the AOT variant).
void run(
    void const *gatheredFP4,
    void const *weight,
    void const *gatheredSF,
    void const *weightSF,
    void *output,
    void const *alpha,
    MoELayout const &layout,
    int64_t permutedM,
    cudaStream_t stream
)
Run the contiguous grouped GEMM with fused alpha + activation.
Unlike the bucketed runner, this takes the layout directly — no per-group metadata construction needed. Alpha scaling and activation are applied inside the kernel epilogue in float32.
- Parameters:
gatheredFP4 – [permutedM, K/2] float4_e2m1fn_x2 on device
weight – [L, K, N/2] float4_e2m1fn_x2 on device (3D stacked, N-major byte layout — N axis innermost, 2 FP4 nibbles per byte along N). Matches the plugin v4 fc_up_qweights shape and the Marlin decode layout.
gatheredSF – atom-layout SF buffer on device (input A scales)
weightSF – atom-layout SF buffer on device (weight B scales, prefill-friendly M=N, K=K/16 — unchanged from v3)
output – [permutedM, N_out] bfloat16 on device (output)
alpha – [L] float32 per-expert scaling on device
layout – MoE layout (tile metadata + permutation indices)
permutedM – Total permuted rows
stream – CUDA stream