Nvfp4 Tensor

namespace trt_edgellm

Typedefs

using Dim3 = int3

Three tile indices for NVFP4Tensor: x maps to strides[0] (slowest), y to strides[1], and z to strides[2] (fastest).

Functions

__device__ __forceinline__ float nvfp4TensorScaleAt (float const *global_scale, int const idx)

Returns global_scale[idx] if global_scale is non-null, otherwise 1.f.
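As a rough illustration of the null-pointer convention, here is a host-side C++ mirror of nvfp4TensorScaleAt. The name nvfp4TensorScaleAtHost is hypothetical, and the __device__ / __forceinline__ qualifiers are dropped so it runs on a CPU; this is a sketch, not the library's implementation.

```cpp
#include <cassert>

// Hypothetical host-side mirror of nvfp4TensorScaleAt.
// A null global_scale pointer means "no global scale", i.e. a factor of 1.f.
inline float nvfp4TensorScaleAtHost(float const* global_scale, int const idx)
{
    assert(idx >= 0);
    return global_scale != nullptr ? global_scale[idx] : 1.f;
}
```

With per-expert scales {0.5f, 2.f, 4.f}, index 1 yields 2.f, while a null pointer yields 1.f for any index.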

Variables

int kNvfp4ElemsPerTile = 64

Number of NVFP4 (E2M1) values per tile. The packed payload of one 64-element tile is 8×uint32 lanes, with 8 FP4 values per uint32.
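To make the E2M1 packing concrete, the sketch below decodes one FP4 value from a packed uint32 lane. The E2M1 value set (sign bit, 2 exponent bits, 1 mantissa bit, magnitudes 0 to 6) is standard; the nibble order within a lane (value i in bits [4i, 4i+3], low nibble first) is an assumption not stated above, and decodeE2M1 / decodeLaneValue are hypothetical helper names.

```cpp
#include <cassert>
#include <cstdint>

// E2M1 magnitudes indexed by the low 3 bits (2 exponent bits, 1 mantissa bit).
constexpr float kE2M1Magnitude[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};

// Decode one E2M1 nibble; bit 3 is the sign.
inline float decodeE2M1(uint32_t nibble)
{
    float const mag = kE2M1Magnitude[nibble & 0x7u];
    return (nibble & 0x8u) ? -mag : mag;
}

// Extract value i (0..7) from a packed uint32 lane.
// ASSUMPTION: value i occupies bits [4*i, 4*i + 3] (low nibble first).
inline float decodeLaneValue(uint32_t lane, int i)
{
    assert(i >= 0 && i < 8);
    return decodeE2M1((lane >> (4 * i)) & 0xFu);
}
```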

int64_t kNvfp4Int4PerTilePayload = 2

Number of int4 vectors per tile payload (8×uint32 lanes == 2×int4); declared int64_t for mixed int4 / int64 stride arithmetic.
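The relationship between the two constants can be checked at compile time. The sketch below uses host-side mirrors of kNvfp4ElemsPerTile and kNvfp4Int4PerTilePayload; the helper constants (kLanesPerTile, kFp4PerLane, kBytesPerInt4) are hypothetical names, and kBytesPerInt4 assumes sizeof(int4) == 16 as in CUDA.

```cpp
#include <cstdint>

// Host-side mirrors of the two documented constants.
constexpr int kNvfp4ElemsPerTile = 64;
constexpr int64_t kNvfp4Int4PerTilePayload = 2;

// Hypothetical helper constants for the payload arithmetic.
constexpr int kBitsPerFp4 = 4;
constexpr int kLanesPerTile = 8;               // uint32 lanes per tile payload
constexpr int kFp4PerLane = 32 / kBitsPerFp4;  // 8 FP4 values per uint32
constexpr int kBytesPerInt4 = 16;              // sizeof(int4) in CUDA

static_assert(kLanesPerTile * kFp4PerLane == kNvfp4ElemsPerTile, "64 FP4 values per tile");
static_assert(kLanesPerTile * 4 == kNvfp4Int4PerTilePayload * kBytesPerInt4, "32 bytes == 2 x int4");
static_assert(kLanesPerTile / 4 == kNvfp4Int4PerTilePayload, "4 uint32 lanes per int4 chunk");
```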

struct NVFP4Tensor

Device view of an NVFP4 tensor stored in 64-element tiles with per-tile FP8 block scales and optional global_scale (applied after block dequantization).

global_scale is a device pointer to E floats (per-expert scales) supplied by the MoE plugin when the tensor holds stacked expert weights; the slowest tile index x is then the expert ID, and kernels read global_scale[tile.x]. For the MoE W4A4 activation tensor, global_scale[0] scales the activation in the up-projection GEMV (after NVFP4 dequantization); the down-projection decode path does not read the activation tensor. In both cases, nullptr means a factor of 1.f.
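The dequantization order described above (E2M1 value times FP8 block scale, then times the optional global scale) can be sketched per element as follows. The E4M3 decoding follows the standard format (bias 7, no infinities, 0x7F/0xFF are NaN); decodeE4M3 and dequantElement are hypothetical host-side helpers, and passing the expert index explicitly stands in for the kernels' use of tile.x.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an FP8 E4M3 byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
// E4M3 has no infinities; exponent 0xF with mantissa 0x7 encodes NaN.
inline float decodeE4M3(uint8_t b)
{
    int const sign = (b >> 7) & 1;
    int const exp = (b >> 3) & 0xF;
    int const man = b & 0x7;
    float v;
    if (exp == 0xF && man == 0x7)
        v = NAN;
    else if (exp == 0)
        v = std::ldexp(man / 8.f, -6);            // subnormal
    else
        v = std::ldexp(1.f + man / 8.f, exp - 7); // normal
    return sign ? -v : v;
}

// Hypothetical per-element dequant: block dequantization first, then the
// global scale (nullptr means 1.f), mirroring the order described above.
inline float dequantElement(float e2m1Value, uint8_t blockScaleByte,
                            float const* global_scale, int expert)
{
    assert(expert >= 0);
    float const globalFactor = global_scale != nullptr ? global_scale[expert] : 1.f;
    return e2m1Value * decodeE4M3(blockScaleByte) * globalFactor;
}
```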

Logical layout is 3D in tiles (each tile is kNvfp4ElemsPerTile NVFP4 values plus one packed scale int). Index order is row-major (C order): z is the fastest-varying tile index, x the slowest.

quantized_data points at the first int4 of the tile at logical origin (0,0,0). Each tile occupies kNvfp4Int4PerTilePayload contiguous int4 vectors. strides[d] is the offset in int4 elements when logical tile index component d increases by 1 (NumPy-style byte strides divided by sizeof(int4)).
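Given that description, a natural reading is that a tile's int4 offset is the dot product of its coordinates with strides. The sketch below assumes exactly that (the body of tileOffsetInt4 is not shown in this reference) and also derives strides for a densely packed tile grid. TileCoord, tileOffsetInt4Host, and contiguousStrides are hypothetical host-side names; TileCoord stands in for Dim3.

```cpp
#include <cassert>
#include <cstdint>

struct TileCoord { int x, y, z; };  // host stand-in for Dim3 (int3)

// ASSUMPTION: offset (in int4 units, relative to quantized_data) is the
// dot product of the tile coordinates with strides.
inline int64_t tileOffsetInt4Host(TileCoord c, int64_t const strides[3])
{
    return c.x * strides[0] + c.y * strides[1] + c.z * strides[2];
}

// Strides for a densely packed [X][Y][Z] tile grid (z fastest) where each
// tile occupies int4PerTile contiguous int4 vectors. X does not affect strides.
inline void contiguousStrides(int64_t Y, int64_t Z, int64_t int4PerTile, int64_t strides[3])
{
    assert(Y > 0 && Z > 0 && int4PerTile > 0);
    strides[2] = int4PerTile;
    strides[1] = Z * int4PerTile;
    strides[0] = Y * Z * int4PerTile;
}
```

For a dense grid with Y = 3, Z = 4 and 2 int4 per tile, the strides are {24, 8, 2}, so tile (1, 2, 3) starts 46 int4 vectors after quantized_data.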

block_scale uses the 128×4 atom-layout swizzle (matching Blackwell tcgen05.mma block-scaled MMA and TMA scale-factor multicasts). Each FP8 E4M3 scale byte is stored at the atom-layout byte offset given below; four consecutive bytes (innerK 0..3) form one int32 word, whose index tileScaleIndex returns and readBlockScaleWord reads.

Atom-layout byte offset for the scale factor at logical (mRow, kChunk):

  • innerK = kChunk % 4
  • innerM = (mRow % 128) / 32
  • outerM = mRow % 32
  • kTile = kChunk / 4
  • mTile = mRow / 128
  • byteOffset = mTile * numKTiles * 512 + kTile * 512 + outerM * 16 + innerM * 4 + innerK

The int32 word index into block_scale is byteOffset / 4; since every term except innerK is a multiple of 4, this simply drops innerK: mTile * numKTiles * 128 + kTile * 128 + outerM * 4 + innerM
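The two formulas above can be transcribed directly. The host-side sketch below (atomScaleByteOffset and atomScaleWordIndex are hypothetical names) implements them as written, so one can check that the word index equals byteOffset / 4 and that one 128×4 atom maps onto 512 distinct bytes.

```cpp
#include <cassert>
#include <cstdint>

// Byte offset of the FP8 scale factor at logical (mRow, kChunk) in the
// 128x4 atom layout; numKTiles = ceil(numSfCols / 4).
inline int64_t atomScaleByteOffset(int64_t mRow, int64_t kChunk, int64_t numKTiles)
{
    assert(mRow >= 0 && kChunk >= 0 && kChunk < numKTiles * 4);
    int64_t const innerK = kChunk % 4;
    int64_t const innerM = (mRow % 128) / 32;
    int64_t const outerM = mRow % 32;
    int64_t const kTile = kChunk / 4;
    int64_t const mTile = mRow / 128;
    return mTile * numKTiles * 512 + kTile * 512 + outerM * 16 + innerM * 4 + innerK;
}

// int32 word index into block_scale: the same formula with innerK dropped.
inline int64_t atomScaleWordIndex(int64_t mRow, int64_t kChunk, int64_t numKTiles)
{
    int64_t const innerM = (mRow % 128) / 32;
    int64_t const outerM = mRow % 32;
    int64_t const kTile = kChunk / 4;
    int64_t const mTile = mRow / 128;
    return mTile * numKTiles * 128 + kTile * 128 + outerM * 4 + innerM;
}
```

For example, with numKTiles = 3, (mRow 33, kChunk 1) lands at byte 21 and (mRow 130, kChunk 5) at byte 2081.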

scaleMDimIdx / scaleKDimIdx select which Dim3 component maps to mRow / kChunk. scaleNumKTiles is ceil(numSfCols / 4) where numSfCols = K / 16. scaleExpertStride is the int32 element stride between experts (ceil(M/128) * numKTiles * 128).
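The two scale metadata fields can be derived from the GEMM dimensions as sketched below. computeScaleNumKTiles and computeScaleExpertStride are hypothetical host-side helpers; requiring K to be a multiple of 16 is an assumption (the text only says numSfCols = K / 16).

```cpp
#include <cassert>
#include <cstdint>

inline int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// scaleNumKTiles = ceil(numSfCols / 4) with numSfCols = K / 16.
// ASSUMPTION: K is a multiple of 16 (one FP8 scale per 16 K elements).
inline int32_t computeScaleNumKTiles(int64_t K)
{
    assert(K > 0 && K % 16 == 0);
    return static_cast<int32_t>(ceilDiv(K / 16, 4));
}

// scaleExpertStride = ceil(M / 128) * scaleNumKTiles * 128, in int32 words.
inline int64_t computeScaleExpertStride(int64_t M, int32_t numKTiles)
{
    assert(M > 0 && numKTiles > 0);
    return ceilDiv(M, 128) * numKTiles * 128;
}
```

For M = 256 and K = 2048 this gives 32 K-tiles and an expert stride of 8192 int32 words; a partial M-tile (M = 100) still reserves a full 128-row atom.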

Public Functions

inline __device__ __forceinline__ int64_t tileOffsetInt4 (Dim3 const c) const
inline __device__ __forceinline__ int64_t tileScaleIndex (Dim3 const c) const

Atom-layout 128×4 swizzle: returns int32 index into block_scale for the tile at c.

inline __device__ __forceinline__ int readBlockScaleWord (Dim3 const c) const

Reads the block-scale int32 word and swaps its two middle bytes for Marlin dequant compatibility: the atom layout stores sequential FP8 bytes {s0,s1,s2,s3}, whereas Marlin's dequant_fp8_scales expects {s0,s2,s1,s3}.
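The reordering from {s0,s1,s2,s3} to {s0,s2,s1,s3} amounts to swapping the two middle bytes of the word. A host-side sketch follows, assuming s0 is the least significant byte (little-endian, as on CUDA devices); atomToMarlinScaleWord is a hypothetical name. On the device, the CUDA intrinsic __byte_perm(w, 0, 0x3120) should produce the same result.

```cpp
#include <cstdint>

// Reorder an atom-layout scale word {s0,s1,s2,s3} (s0 = least significant byte)
// into the {s0,s2,s1,s3} order expected by Marlin's dequant_fp8_scales.
constexpr uint32_t atomToMarlinScaleWord(uint32_t w)
{
    return (w & 0xFF0000FFu)          // s0 and s3 stay in place
         | ((w & 0x0000FF00u) << 8)   // s1 moves to byte 2
         | ((w & 0x00FF0000u) >> 8);  // s2 moves to byte 1
}
```

Because it only swaps two bytes, applying it twice returns the original word.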

inline __device__ __forceinline__ int readBlockScaleWordLinear (Dim3 const c) const

Reads the block-scale int32 word using plain linear tile indexing (one scale word per tile payload). The word is already in Marlin-packed byte order, so no byte-swap is applied. Used for activation scale factors whose layout matches TRT_DynamicQuantize output (contiguous, not atom-swizzled).
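One way to picture "one scale word per tile payload" is sketched below, assuming the activation scale words are stored contiguously in the same row-major (x, y, z) order as a dense tile grid. The exact indexing used by readBlockScaleWordLinear is not shown in this reference, so both helpers (linearTileIndex, linearTileIndexFromInt4Offset) are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// HYPOTHETICAL: row-major linear index of tile (x, y, z) in a dense
// [X][Y][Z] tile grid, z fastest. With one int32 scale word per tile, this
// is also the word index into a contiguous (non-swizzled) scale buffer.
inline int64_t linearTileIndex(int64_t x, int64_t y, int64_t z, int64_t Y, int64_t Z)
{
    assert(x >= 0 && y >= 0 && y < Y && z >= 0 && z < Z);
    return (x * Y + y) * Z + z;
}

// HYPOTHETICAL: the same index recovered from a dense int4 payload offset;
// each tile spans int4PerTile int4 vectors but owns a single scale word.
inline int64_t linearTileIndexFromInt4Offset(int64_t int4Offset, int64_t int4PerTile)
{
    assert(int4PerTile > 0 && int4Offset % int4PerTile == 0);
    return int4Offset / int4PerTile;
}
```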

inline __device__ __forceinline__ void loadTileUint4 (Dim3 const tile, int const chunk, uint4 &out) const

Loads one int4 chunk (4 uint32 NVFP4 lane packs) from the tile payload.

Parameters:

  • chunk: 0 or 1; the tile payload is kNvfp4Int4PerTilePayload int4 vectors.
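A host-side sketch of loading one chunk, assuming chunk `chunk` is the int4 immediately following the tile's start offset (consistent with tiles occupying contiguous int4 vectors). U32x4 stands in for CUDA's uint4/int4, and loadTileChunk and kInt4PerTile are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

struct U32x4 { uint32_t x, y, z, w; };  // host stand-in for CUDA uint4 / int4

constexpr int64_t kInt4PerTile = 2;     // mirrors kNvfp4Int4PerTilePayload

// ASSUMPTION: chunk `chunk` of a tile is the int4 at tileOffset + chunk,
// where tileOffset is the tile's offset from quantized_data in int4 units.
inline U32x4 loadTileChunk(U32x4 const* quantized_data, int64_t tileOffset, int chunk)
{
    assert(chunk >= 0 && chunk < kInt4PerTile);
    return quantized_data[tileOffset + chunk];
}
```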

inline __device__ __forceinline__ uint32_t loadTileUint32Lane (Dim3 const tile, int const chunk, int const lane) const

Loads one uint32 lane (8 packed NVFP4 values) from the tile payload.

Parameters:
  • chunk: 0 or 1.

  • lane: 0..3 within that chunk (see loadTileUint4).
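Picking one uint32 lane out of a loaded chunk can be sketched as below, assuming lane i corresponds to vector component x, y, z, w in order. This numbering is an assumption, and selectLane / flatLaneIndex are hypothetical helpers; U32x4 is again a host stand-in for uint4.

```cpp
#include <cassert>
#include <cstdint>

struct U32x4 { uint32_t x, y, z, w; };  // host stand-in for CUDA uint4

// ASSUMPTION: lane 0..3 maps to components x, y, z, w of the chunk.
inline uint32_t selectLane(U32x4 const& v, int lane)
{
    assert(lane >= 0 && lane < 4);
    switch (lane)
    {
    case 0: return v.x;
    case 1: return v.y;
    case 2: return v.z;
    default: return v.w;
    }
}

// Flat lane index within the 8-lane tile payload (chunk 0 holds lanes 0..3).
inline int flatLaneIndex(int chunk, int lane)
{
    assert(chunk >= 0 && chunk < 2 && lane >= 0 && lane < 4);
    return chunk * 4 + lane;
}
```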

Public Members

int4 *quantized_data
int *block_scale
float *global_scale

Device pointer to E per-expert scales (MoE plugin inputs), or nullptr. See struct comment.

int64_t strides[3]

Stride in int4 elements for a +1 step in tile index dimension 0..2 (z fastest).

int32_t scaleNumKTiles

Atom-layout scale factor metadata.

K-tiles for atom swizzle: ceil(numSfCols / 4) where numSfCols = K_dim / 16.

int64_t scaleExpertStride

int32 elements between experts: ceil(M/128) * scaleNumKTiles * 128.

int8_t scaleMDimIdx

Which Dim3 component is the M-row (0 for act, 1 for weights).

int8_t scaleKDimIdx

Which Dim3 component is the K-chunk (1 for act, 2 for weights).