Kernels#

namespace trt_edgellm

Enums

enum class MoEActivationKind : int8_t#

Nonlinearity applied to the intermediate activation after the up-proj dot product (before post-nonlinearity scaling in the decode kernels).

Values:

enumerator kReLU2 = 0#

max(z,0)^2

enumerator kSiLU = 1#

z * sigmoid(z)
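A host-side reference of the two formulas may help; applyMoeActivation is a hypothetical helper (the real kernels evaluate these on-device):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Mirrors the documented enum; values match the ones listed above.
enum class MoEActivationKind : int8_t { kReLU2 = 0, kSiLU = 1 };

// Hypothetical host-side reference of the two nonlinearities.
inline float applyMoeActivation(MoEActivationKind kind, float z) {
    switch (kind) {
    case MoEActivationKind::kReLU2: {
        float r = std::max(z, 0.0f);       // max(z, 0)^2
        return r * r;
    }
    case MoEActivationKind::kSiLU:
        return z / (1.0f + std::exp(-z));  // z * sigmoid(z)
    }
    return 0.0f;  // unreachable for valid kinds
}
```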

Functions

inline bool moeW4a4DecodeIsValidThreadBlockSize(
int const thread_block_size
) noexcept#

Returns whether thread_block_size is one of the supported values for the W4A4 decode GEMV kernels (a multiple of the warp size).
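A minimal sketch of this predicate, assuming (per the description) that any positive multiple of the 32-thread warp size is accepted; whether the library also imposes an upper bound or a fixed candidate set is not stated here:

```cpp
// Sketch only: accepts any positive multiple of the 32-thread warp size.
inline bool isValidW4a4ThreadBlockSizeSketch(int thread_block_size) noexcept {
    constexpr int kWarpSize = 32;
    return thread_block_size > 0 && thread_block_size % kWarpSize == 0;
}
```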

inline int nemotronMoeW4A16DecodeThreadBlockSizeForInterDim(
int const inter_dim
) noexcept#

Picks the W4A16 decode thread block size from inter_dim: the largest of 256, 128, 96, or 64 that divides inter_dim.

Returns:

0 if inter_dim is not divisible by 64.

inline int nemotronMoeW4A4DecodeThreadBlockSizeForDims(
int const hidden_dim,
int const inter_dim
) noexcept#

Picks a thread_block_size accepted by moeW4a4DecodeIsValidThreadBlockSize that divides both hidden_dim and inter_dim; W4A4 decode uses a single block size for both the up projection (strips along hidden_dim) and the down projection (strips along inter_dim).

Returns:

0 if no candidate divides both dimensions.

inline int moeDecodeGemvTopkGridDim(
int const batch_size,
int const top_k,
int const inter_dim,
int const thread_block_size
) noexcept#

Grid x-dimension for MoE decode GEMV with explicit top-k routing: one block per (token row, top-k slot, thread_block_size-sized strip of inter_dim). inter_dim must be a positive multiple of thread_block_size. Token rows are batch * seq_len (flattened row-major [batch, seq_len, …]).
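The grid computation described above reduces to a product of three counts; a host-side sketch with an illustrative name:

```cpp
// One block per (token row, top-k slot, thread_block_size-wide strip of inter_dim).
// Precondition from the docs: inter_dim is a positive multiple of thread_block_size.
inline int moeDecodeGemvTopkGridDimSketch(int num_token_rows, int top_k,
                                          int inter_dim, int thread_block_size) noexcept {
    int const strips = inter_dim / thread_block_size;
    return num_token_rows * top_k * strips;
}
```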

inline int moeDecodeGemvTopkGridDimBatchSeq(
int const batch,
int const seq_len,
int const top_k,
int const strip_dim,
int const thread_block_size
) noexcept#

Same as moeDecodeGemvTopkGridDim with num_tokens = batch * seq_len. Pass strip_dim = hidden_dim for W4A16 up (strips along hidden) or inter_dim for W4A16 down (strips along intermediate).

inline int moeW4a4DecodeUpGridDim(
int const batch,
int const seq_len,
int const top_k,
int const hidden_dim,
int const thread_block_size
) noexcept#

Grid x-dimension for W4A4 decode up kernel: strips along hidden_dim (same as launchNemotronMoeW4A4DecodeUpGemvCuda). num_tokens = batch * seq_len (flattened row-major tokens).

inline int moeDecodeGemvTopkThreads(
int const batch_size,
int const top_k,
int const inter_dim
) noexcept#

Number of per-intermediate-index MAC lanes (one thread each) for top-k MoE decode GEMV: batch_size * top_k * inter_dim.

inline int64_t nemotronMoeW4A16InterBufferNumElems(
int batch_size,
int top_k,
int inter_dim
) noexcept#

Element count for W4A16/W4A4 split intermediate tensor [batch * seq_len, top_k, inter_dim] (row-major); stored as FP16 between up and down.

inline int64_t nemotronMoeW4A16UpFp16ScratchBytes(
int batch_size,
int top_k,
int inter_dim
) noexcept#

Device scratch bytes: FP16 row-major [batch * seq_len, top_k, inter_dim] (up-proj dot z), passed to down-proj as __half*.
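Sizing both helpers is straightforward arithmetic; a host-side sketch, with std::uint16_t standing in for CUDA's __half (both 2 bytes):

```cpp
#include <cstdint>

// Elements of the row-major [num_tokens, top_k, inter_dim] split intermediate tensor.
inline std::int64_t interBufferNumElemsSketch(int num_tokens, int top_k,
                                              int inter_dim) noexcept {
    return static_cast<std::int64_t>(num_tokens) * top_k * inter_dim;
}

// FP16 scratch bytes for the same tensor (sizeof(__half) == 2 on device).
inline std::int64_t upFp16ScratchBytesSketch(int num_tokens, int top_k,
                                             int inter_dim) noexcept {
    return interBufferNumElemsSketch(num_tokens, top_k, inter_dim)
         * static_cast<std::int64_t>(sizeof(std::uint16_t));
}
```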

Variables

int kDefaultMlpW4a4DecodeThreadBlockSize = 128#

Default CUDA block size for W4A4 decode: the single block size must divide both hidden_dim (up strips) and inter_dim (down strips); see nemotronMoeW4A4DecodeThreadBlockSizeForDims. hidden_dim must be divisible by 64.

int kMaxDecodingKernelWarpCount = 16#

Upper bound on warps per block for the MoE decode GEMV shared scratch (accumulateNvfp4GemvTileWarpReduce in marlin_template.cuh). Must be at least 256 / 32 = 8 for the current largest thread_block_size of 256.
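The relationship between this bound and the largest block size can be checked at compile time; a sketch mirroring the documented constants (names suffixed to mark them as illustrative):

```cpp
// Sketch constants mirroring the documented values.
constexpr int kWarpSize = 32;
constexpr int kMaxDecodingKernelWarpCountSketch = 16;
constexpr int kLargestThreadBlockSize = 256;

// 256 / 32 = 8 warps in the largest block; the bound of 16 leaves headroom.
static_assert(kMaxDecodingKernelWarpCountSketch >= kLargestThreadBlockSize / kWarpSize,
              "decode GEMV shared scratch must cover every warp in a block");
```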