Nvfp4 Tensor
-
namespace trt_edgellm
Typedefs
-
using Dim3 = int3
Three tile indices for NVFP4Tensor:
x → strides[0] (slowest), y → strides[1], z → strides[2] (fastest).
Functions
- __device__ __forceinline__ float nvfp4TensorScaleAt (float const *global_scale, int const idx)
Returns global_scale[idx] if global_scale is non-null, else 1.f.
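The null-pointer rule above can be sketched on the host as follows. scaleAtHost and applyGlobalScale are hypothetical names, not part of trt_edgellm; the second helper only illustrates that the global scale multiplies an already block-dequantized value, with the expert id (tile.x) as the index for stacked MoE weights.

```cpp
#include <cassert>

// Hypothetical host stand-in for nvfp4TensorScaleAt: a null pointer means an
// implicit scale of 1.f, otherwise the scale at idx is returned.
inline float scaleAtHost(float const* global_scale, int const idx) {
    return global_scale != nullptr ? global_scale[idx] : 1.f;
}

// Hypothetical helper: apply the global scale after block dequantization.
// For stacked MoE expert weights, expertId would be the slow tile index (tile.x).
inline float applyGlobalScale(float blockDequantized, float const* global_scale, int expertId) {
    return blockDequantized * scaleAtHost(global_scale, expertId);
}
```

Returning 1.f for a null pointer lets callers skip a branch on "has global scale" and always multiply.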
Variables
-
int kNvfp4ElemsPerTile = 64
Packed NVFP4 (E2M1) payload for one 64-element tile: 8×uint32 lanes, 8 FP4 values per uint32.
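To make the "8 FP4 values per uint32" packing concrete, here is a minimal host-side decoder for one lane. It uses the standard E2M1 code table (sign bit, 2 exponent bits, 1 mantissa bit). The nibble order within a lane (value i in bits [4i, 4i+4)) is an assumption for illustration only; the real kernels may use a different, for example interleaved, order. decodeE2M1 and unpackLane are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Decode one FP4 E2M1 code: bit 3 = sign, bits 2..0 = magnitude code.
inline float decodeE2M1(uint32_t nibble) {
    static constexpr float kMag[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
    float const mag = kMag[nibble & 0x7u];
    return (nibble & 0x8u) ? -mag : mag;
}

// ASSUMPTION: value i of a lane lives in bits [4*i, 4*i + 4) (low nibble first).
// This returns raw E2M1 values; block and global scales are not applied here.
inline std::array<float, 8> unpackLane(uint32_t lane) {
    std::array<float, 8> out{};
    for (int i = 0; i < 8; ++i) {
        out[i] = decodeE2M1((lane >> (4 * i)) & 0xFu);
    }
    return out;
}
```

Eight such lanes give the 64 values of one tile.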
-
int64_t kNvfp4Int4PerTilePayload = 2
Number of int4 vectors per tile payload (8×uint32 lanes == 2×int4); declared int64_t for mixed int4/int64 stride math.
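The value 2 follows from simple size arithmetic: 64 values × 4 bits = 256 bits = 32 bytes, which is eight 4-byte uint32 lanes or two 16-byte int4 vectors. The sketch below checks this at compile time; the constant names are hypothetical local copies, and sizeof(int4) == 16 is CUDA's vector-type size.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical local copies of the documented constants.
constexpr int kElemsPerTile = 64;      // kNvfp4ElemsPerTile
constexpr int kBitsPerElem = 4;        // FP4 E2M1
constexpr int kBytesPerInt4 = 16;      // sizeof(int4) in CUDA
constexpr int kBytesPerUint32 = 4;

constexpr int kPayloadBytes = kElemsPerTile * kBitsPerElem / 8;   // 32 bytes
constexpr int kUint32Lanes = kPayloadBytes / kBytesPerUint32;     // 8 lanes
constexpr int64_t kInt4PerTile = kPayloadBytes / kBytesPerInt4;   // 2 int4 vectors

static_assert(kUint32Lanes == 8, "8 x uint32 lanes per tile");
static_assert(kUint32Lanes * 8 == kElemsPerTile, "8 FP4 values per uint32 lane");
static_assert(kInt4PerTile == 2, "matches kNvfp4Int4PerTilePayload");
```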
-
struct NVFP4Tensor
Device view of an NVFP4 tensor stored in 64-element tiles with per-tile FP8 block scales and an optional global_scale (applied after block dequantization).
global_scale is a device pointer to E floats from the MoE plugin (per-expert scales) when the tensor holds stacked expert weights (the slowest tile index x is the expert id); kernels use global_scale[tile.x]. nullptr means a factor of 1.f. For MoE W4A4 activations, global_scale[0] scales the activation in the up-proj GEMV (after NVFP4 dequantization); the down-proj decode does not read the activation tensor. Here, nullptr means 1.f for index 0.

The logical layout is 3D in tiles (each tile is kNvfp4ElemsPerTile NVFP4 values plus one packed scale int). Index order is row-major (C order): z is the fastest-varying tile index, x the slowest. quantized_data points at the first int4 of the tile at logical origin (0,0,0). Each tile occupies kNvfp4Int4PerTilePayload contiguous int4 vectors. strides[d] is the offset in int4 elements when logical tile index component d increases by 1 (NumPy-style byte strides divided by sizeof(int4)).

block_scale uses the atom-layout 128×4 swizzle (matching Blackwell tcgen05.mma block-scaled MMA and TMA scale-factor multicasts). Each FP8 E4M3 byte is stored at a byte offset computed by tileScaleIndex; four consecutive bytes (innerK 0..3) form one int32 word, which readBlockScaleWord reads. The atom-layout byte offset for the scale factor at logical (mRow, kChunk) is:

innerK = kChunk % 4
innerM = (mRow % 128) / 32
outerM = mRow % 32
kTile = kChunk / 4
mTile = mRow / 128
byteOffset = mTile * numKTiles * 512 + kTile * 512 + outerM * 16 + innerM * 4 + innerK
The int32 word index (into the block_scale pointer) drops innerK:

mTile * numKTiles * 128 + kTile * 128 + outerM * 4 + innerM

scaleMDimIdx / scaleKDimIdx select which Dim3 component maps to mRow / kChunk. scaleNumKTiles is ceil(numSfCols / 4), where numSfCols = K / 16. scaleExpertStride is the int32 element stride between experts (ceil(M/128) * numKTiles * 128).

Public Functions
- inline __device__ __forceinline__ int64_t tileOffsetInt4 (Dim3 const c) const
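Going by the stride definition in the struct description, tileOffsetInt4 should amount to a dot product of the tile index with strides. The host sketch below assumes exactly that; TileIdx, tileOffsetInt4Host and denseStrides are hypothetical, and denseStrides only shows what the strides would be for a densely packed row-major [X][Y][Z] tile grid, which the real tensor is not required to be.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host stand-in for CUDA's int3 / Dim3.
struct TileIdx { int x, y, z; };

// ASSUMPTION: offset (in int4 units from quantized_data) is the stride-weighted
// sum of the tile index components, per the documented meaning of strides[d].
inline int64_t tileOffsetInt4Host(int64_t const strides[3], TileIdx c) {
    return c.x * strides[0] + c.y * strides[1] + c.z * strides[2];
}

// Hypothetical helper: strides for a dense row-major [X][Y][Z] tile grid where
// each tile occupies int4PerTile int4 vectors (z fastest). X does not appear.
inline void denseStrides(int64_t Y, int64_t Z, int64_t int4PerTile, int64_t strides[3]) {
    strides[2] = int4PerTile;
    strides[1] = Z * int4PerTile;
    strides[0] = Y * Z * int4PerTile;
}
```

With Y = 3, Z = 4 and 2 int4 per tile, the strides are {24, 8, 2}, so tile (1, 2, 3) starts 46 int4 vectors after the origin. Keeping the arithmetic in int64_t avoids overflow for large stacked-expert tensors.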
- inline __device__ __forceinline__ int64_t tileScaleIndex (Dim3 const c) const
Atom-layout 128×4 swizzle: returns the int32 index into block_scale for the tile at c.
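The atom-layout formulas from the struct description can be transcribed directly and checked against each other: dividing the byte offset by 4 should give the int32 word index, because innerK < 4 and every other term is a multiple of 4. The function names are hypothetical; the per-expert offset (scaleExpertStride) is deliberately left out, since how tileScaleIndex combines it is not spelled out here.

```cpp
#include <cassert>
#include <cstdint>

// Byte offset of the FP8 scale at logical (mRow, kChunk) in the 128x4 atom layout.
// kChunk counts 16-element groups along K; numKTiles = ceil(numSfCols / 4).
inline int64_t atomScaleByteOffset(int64_t mRow, int64_t kChunk, int64_t numKTiles) {
    int64_t const innerK = kChunk % 4;
    int64_t const innerM = (mRow % 128) / 32;
    int64_t const outerM = mRow % 32;
    int64_t const kTile = kChunk / 4;
    int64_t const mTile = mRow / 128;
    return mTile * numKTiles * 512 + kTile * 512 + outerM * 16 + innerM * 4 + innerK;
}

// int32 word index: the same formula with innerK dropped, since one word holds
// the four scales innerK = 0..3 of a 64-element tile.
inline int64_t atomScaleWordIndex(int64_t mRow, int64_t kTile, int64_t numKTiles) {
    int64_t const innerM = (mRow % 128) / 32;
    int64_t const outerM = mRow % 32;
    int64_t const mTile = mRow / 128;
    return mTile * numKTiles * 128 + kTile * 128 + outerM * 4 + innerM;
}
```

Each 128×4 atom occupies 512 bytes (outerM * 16 + innerM * 4 + innerK spans 0..511), which is why whole atoms are advanced in steps of 512 bytes or 128 words.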
- inline __device__ __forceinline__ int readBlockScaleWord (Dim3 const c) const
Reads the block-scale int32 word and byte-swaps it for Marlin dequant compatibility. The atom layout stores sequential FP8 bytes {s0,s1,s2,s3}; Marlin's dequant_fp8_scales expects {s0,s2,s1,s3}.
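The swap described above exchanges the middle two bytes of the word. A portable host sketch, assuming byte k of the word is bits [8k, 8k+8) as on little-endian GPUs, is shown below; atomToMarlinScaleWord is a hypothetical name. On the device, CUDA's __byte_perm(w, 0, 0x3120) should produce the same permutation, though the actual readBlockScaleWord implementation may do it differently.

```cpp
#include <cassert>
#include <cstdint>

// Reorder the four FP8 scale bytes of one word from atom order {s0,s1,s2,s3}
// to the {s0,s2,s1,s3} order expected by the Marlin dequant path.
inline uint32_t atomToMarlinScaleWord(uint32_t w) {
    return (w & 0xFF0000FFu)           // s0 and s3 stay in place
         | ((w & 0x0000FF00u) << 8)    // s1 moves to byte 2
         | ((w & 0x00FF0000u) >> 8);   // s2 moves to byte 1
}
```

The permutation is its own inverse, so applying it twice restores the original word.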
- inline __device__ __forceinline__ int readBlockScaleWordLinear (Dim3 const c) const
Reads the block-scale int32 word using plain linear tile indexing (one scale word per tile payload). The word is already in Marlin-packed byte order, so no byte-swap is applied. Used for activation scale factors whose layout matches the TRT_DynamicQuantize output (contiguous, not atom-swizzled).
- inline __device__ __forceinline__ void loadTileUint4 (Dim3 const tile, int const chunk, uint4 &out) const
Loads one int4 chunk (4×uint32 NVFP4 lane packs) from the tile payload.
- Parameters:
chunk – 0 or 1; the tile payload is kNvfp4Int4PerTilePayload int4 vectors.
- inline __device__ __forceinline__ uint32_t loadTileUint32Lane (Dim3 const tile, int const chunk, int const lane) const
Loads one uint32 lane (8 packed NVFP4 values) from the tile payload.
- Parameters:
chunk – 0 or 1.
lane – 0..3 within that chunk (see loadTileUint4).
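The chunk/lane addressing of the two loaders can be sketched on a flat uint32 view of the payload. This assumes each int4 holds four consecutive 32-bit lanes with lane 0 at the lowest address (the x component), and that the tile's start is its int4 offset; U4, laneWordIndex and loadChunk are hypothetical host stand-ins, not the device functions themselves.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host stand-in for uint4: four 32-bit words, x at the lowest address.
struct U4 { uint32_t x, y, z, w; };

// Index of a uint32 lane in a flat uint32 view of quantized_data, given the tile's
// offset in int4 units. ASSUMPTION: lane 0..3 maps to int4 components x..w.
inline int64_t laneWordIndex(int64_t tileOffsetInt4, int chunk, int lane) {
    assert(chunk >= 0 && chunk < 2);   // kNvfp4Int4PerTilePayload == 2
    assert(lane >= 0 && lane < 4);     // 4 uint32 lanes per int4 chunk
    return (tileOffsetInt4 + chunk) * 4 + lane;
}

// Host analogue of loadTileUint4: copy one int4 chunk of a tile's payload.
inline U4 loadChunk(std::vector<uint32_t> const& words, int64_t tileOffsetInt4, int chunk) {
    int64_t const base = laneWordIndex(tileOffsetInt4, chunk, 0);
    return U4{words[base], words[base + 1], words[base + 2], words[base + 3]};
}
```

Under these assumptions, (chunk, lane) enumerates the 8 lanes of a tile as chunk * 4 + lane, and loading one int4 is just four adjacent lane loads done as a single 16-byte access.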
Public Members
-
int4 *quantized_data
-
int *block_scale
-
float *global_scale
Device pointer to E per-expert scales (MoE plugin inputs), or nullptr. See the struct description.
-
int64_t strides[3]
Stride in int4 elements per +1 in tile index dim 0..2 (z fastest).
-
int32_t scaleNumKTiles
Atom-layout scale factor metadata.
Number of K-tiles for the atom swizzle: ceil(numSfCols / 4), where numSfCols = K_dim / 16.
-
int64_t scaleExpertStride
int32 elements between experts: ceil(M/128) * scaleNumKTiles * 128.
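A short worked version of the two metadata formulas, with hypothetical helper names. It assumes K is a multiple of 16 (one FP8 scale per 16 values), since K / 16 truncates. M is padded up to whole 128-row atoms and the K-chunks up to whole 4-wide atoms, so each expert's scales occupy whole 512-byte atoms.

```cpp
#include <cassert>
#include <cstdint>

inline int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// scaleNumKTiles = ceil(numSfCols / 4), numSfCols = K / 16 (assumes K % 16 == 0).
inline int32_t scaleNumKTilesFor(int64_t K) {
    int64_t const numSfCols = K / 16;
    return static_cast<int32_t>(ceilDiv(numSfCols, 4));
}

// scaleExpertStride = ceil(M / 128) * scaleNumKTiles * 128, in int32 words.
inline int64_t scaleExpertStrideFor(int64_t M, int64_t K) {
    return ceilDiv(M, 128) * scaleNumKTilesFor(K) * 128;
}
```

For example, M = 4096 and K = 2048 give 32 K-tiles and a stride of 32 * 32 * 128 = 131072 words; M = 200 and K = 80 round up to 2 M-tiles and 2 K-tiles, giving 512 words.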
-
int8_t scaleMDimIdx
Which Dim3 component is the M-row (0 for act, 1 for weights).
-
int8_t scaleKDimIdx
Which Dim3 component is the K-chunk (1 for act, 2 for weights).
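The two index fields select components of the tile index. For weights (M = 1, K = 2) the tile is (expert, mRow, kTile), with x as the expert id per the struct description; for activations (M = 0, K = 1) it is (mRow, kTile, ·). The selector below is a hypothetical host sketch, not the library's code.

```cpp
#include <cassert>
#include <cstdint>

struct TileIdx { int x, y, z; };  // hypothetical host stand-in for int3 / Dim3

// Pick Dim3 component 0/1/2 (x/y/z), as scaleMDimIdx / scaleKDimIdx select.
inline int component(TileIdx c, int8_t idx) {
    return idx == 0 ? c.x : (idx == 1 ? c.y : c.z);
}
```

Storing the mapping as data lets one NVFP4Tensor view and one scale-index routine serve both 2D activations and 3D stacked-expert weights.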