LLM Engine Config
-
struct LLMEngineConfig
Unified configuration for base, vanilla decode, and SpecDecode draft engines.
Replaces LLMEngineRunnerConfig and EagleDraftEngineRunnerConfig with a single structure that can be parsed once and used across all runtime components (KV cache, RoPE, LoRA, preprocessors, etc.).
Public Functions
- InferenceDims prefillDims(int64_t batch, int64_t seqLen, bool kvCacheAllEmpty)
Prefill dims (vanilla LLM, SpecDecode base, and SpecDecode draft). seqLen is the prompt length processed in this step. kvCacheAllEmpty signals whether this is the initial prefill of an empty KV cache. For plugin-path engines it sets the kvcache_start_index shape to [0] (the engine’s “initial prefill” sentinel) instead of [batch]; TRT-native-ops engines always use [batch] regardless.
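A minimal sketch of the kvcache_start_index shape rule above. The helper name kvcacheStartIndexShape and the representation of a shape as std::vector<int64_t> are assumptions made for illustration; this is not the runtime's code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the runtime API): shape of the
// kvcache_start_index binding for one prefill step.
std::vector<int64_t> kvcacheStartIndexShape(int64_t batch, bool kvCacheAllEmpty, bool useTrtNativeOps)
{
    assert(batch > 0);
    if (!useTrtNativeOps && kvCacheAllEmpty)
    {
        // Plugin path with an empty KV cache: shape [0] (zero elements) is the
        // engine's "initial prefill" sentinel.
        return {0};
    }
    // Otherwise, and always for TRT-native-ops engines: one start index per sequence.
    return {batch};
}
```

Keeping the sentinel decision in one place means the plugin/native split is handled once rather than at every call site that binds the tensor.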
-
InferenceDims decodeDims(int64_t batch) const
Vanilla single-token decode dims. seqLen is always 1; packedMaskLen is also 1 because vanilla decode has no proposal mask.
-
InferenceDims specVerifyDims(int64_t batch, int64_t verifySize) const
SpecDecode base verification dims. verifySize feeds three fields: seqLen, selectLen, and packedMaskLen. This is the only base-engine recipe where selectLen != 1 (on the draft side, proposalDims sets selectLen to draftTopK).
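The decodeDims and specVerifyDims recipes can be sketched with a simplified stand-in for InferenceDims. InferenceDimsSketch, its batch member, decodeDimsSketch, and specVerifyDimsSketch are all hypothetical. The sketch copies verifySize directly into each field it feeds. If the real recipe derives packedMaskLen from verifySize instead of copying it, this sketch does not show that step.

```cpp
#include <cassert>
#include <cstdint>

// Simplified stand-in for InferenceDims: only fields named in this reference,
// plus an assumed `batch` member.
struct InferenceDimsSketch
{
    int64_t batch = 1;
    int64_t seqLen = 1;
    int64_t selectLen = 1;
    int64_t packedMaskLen = 1;
};

// Vanilla decode: one token per sequence and no proposal mask.
InferenceDimsSketch decodeDimsSketch(int64_t batch)
{
    InferenceDimsSketch d;
    d.batch = batch;
    d.seqLen = 1;
    d.selectLen = 1;
    d.packedMaskLen = 1;
    return d;
}

// SpecDecode base verification: verifySize feeds seqLen, selectLen, and packedMaskLen.
InferenceDimsSketch specVerifyDimsSketch(int64_t batch, int64_t verifySize)
{
    assert(verifySize >= 1);
    InferenceDimsSketch d;
    d.batch = batch;
    d.seqLen = verifySize;
    d.selectLen = verifySize;
    d.packedMaskLen = verifySize;
    return d;
}
```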
- InferenceDims proposalDims(int64_t batch, int64_t proposalSize, int64_t draftTopK)
SpecDecode draft proposal dims. proposalSize feeds seqLen, attnMaskSeqLen, and packedMaskLen. draftTopK feeds selectLen: the draft proposal selects draftTopK tokens per sequence, matching the 3D logits output shape [batch, draftTopK, draftVocabSize].
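A sketch of the proposal recipe and the resulting draft logits shape. ProposalDimsSketch, proposalDimsSketch, and draftLogitsShape are illustrative names. Copying proposalSize straight into each field is an assumption about what "feeds" means here.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical subset of InferenceDims for the draft proposal step.
struct ProposalDimsSketch
{
    int64_t batch = 1;
    int64_t seqLen = 1;
    int64_t attnMaskSeqLen = 1;
    int64_t packedMaskLen = 1;
    int64_t selectLen = 1;
};

ProposalDimsSketch proposalDimsSketch(int64_t batch, int64_t proposalSize, int64_t draftTopK)
{
    assert(proposalSize >= 1 && draftTopK >= 1);
    ProposalDimsSketch d;
    d.batch = batch;
    d.seqLen = proposalSize;         // proposalSize feeds seqLen,
    d.attnMaskSeqLen = proposalSize; // attnMaskSeqLen,
    d.packedMaskLen = proposalSize;  // and packedMaskLen (copied directly here)
    d.selectLen = draftTopK;         // draftTopK tokens selected per sequence
    return d;
}

// The draft logits output is 3D: [batch, selectLen (= draftTopK), draftVocabSize].
std::array<int64_t, 3> draftLogitsShape(ProposalDimsSketch const& d, int64_t draftVocabSize)
{
    return {d.batch, d.selectLen, draftVocabSize};
}
```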
-
InferenceDims acceptDims(int64_t batch, int64_t acceptLen) const
SpecDecode draft accept-token dims. acceptLen lies in [1, draftingStep + 1] and feeds seqLen and packedMaskLen.
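The documented acceptLen range can be enforced as shown below. This hypothetical acceptDimsSketch takes draftingStep as an explicit extra parameter. The real acceptDims does not take it, and this page does not say where that bound comes from. Throwing std::out_of_range is also an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical subset of InferenceDims for the accept-token step.
struct AcceptDimsSketch
{
    int64_t batch = 1;
    int64_t seqLen = 1;
    int64_t packedMaskLen = 1;
};

// Illustrative only: draftingStep is an extra parameter that the real acceptDims lacks.
AcceptDimsSketch acceptDimsSketch(int64_t batch, int64_t acceptLen, int64_t draftingStep)
{
    // Documented range: acceptLen in [1, draftingStep + 1].
    if (acceptLen < 1 || acceptLen > draftingStep + 1)
    {
        throw std::out_of_range("acceptLen " + std::to_string(acceptLen) + " is outside [1, "
                                + std::to_string(draftingStep + 1) + "]");
    }
    AcceptDimsSketch d;
    d.batch = batch;
    d.seqLen = acceptLen;        // acceptLen feeds seqLen
    d.packedMaskLen = acceptLen; // and packedMaskLen (copied directly here)
    return d;
}
```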
-
InferenceDims resetDims() const
Clean-slate dims for CUDA-graph capture reset. Every field is 1 except kvLen, which is set to maxKVCacheCapacity. ropeBatch stays at 1 even for MRope models; this matches the pre-migration behavior of both runtimes (reset is a binding placeholder, not an inference step).
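A sketch of the reset recipe. It uses a hypothetical subset of InferenceDims that contains only the fields this page mentions, plus an assumed batch member.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical subset of InferenceDims used to illustrate the reset recipe.
struct ResetDimsSketch
{
    int64_t batch;
    int64_t seqLen;
    int64_t selectLen;
    int64_t packedMaskLen;
    int64_t attnMaskSeqLen;
    int64_t ropeBatch;
    int64_t kvLen;
};

// Reset is a binding placeholder for CUDA-graph capture, not an inference step:
// everything is 1 except kvLen, which spans the full KV cache capacity.
ResetDimsSketch resetDimsSketch(int64_t maxKVCacheCapacity)
{
    assert(maxKVCacheCapacity > 0);
    return ResetDimsSketch{1, 1, 1, 1, 1, /*ropeBatch=*/1, /*kvLen=*/maxKVCacheCapacity};
}
```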
Public Members
Model hidden dimension.
-
int32_t outputVocabSize = {}
Actual output vocabulary size (reduced when vocabulary reduction is active).
-
int32_t numAttentionLayers = {}
Number of attention layers that need a KV cache.
-
int32_t numKVHeads = {}
Number of key-value heads.
-
int32_t headDim = {}
Dimension of each attention head.
-
int32_t maxSupportedBatchSize = {}
Maximum supported batch size.
-
int32_t maxSupportedInputLength = {}
Maximum supported input length.
-
int32_t maxKVCacheCapacity = {}
Maximum KV cache capacity (sequence length).
-
int32_t rotaryDim = {}
Rotary embedding dimension.
-
int32_t numDecoderLayers = {}
Total number of decoder layers (attention + linear).
-
int32_t vocabSize = {}
Full vocabulary size.
-
int32_t reducedVocabSize = {0}
Reduced vocabulary size (0 = no vocabulary reduction).
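The member descriptions above imply a simple relationship between outputVocabSize, vocabSize, and reducedVocabSize. That relationship is inferred from the comments, not taken from the runtime source, and effectiveOutputVocabSize is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Assumed relationship: the output vocab is the reduced vocab when reduction is
// active (reducedVocabSize > 0), and the full vocab otherwise.
int32_t effectiveOutputVocabSize(int32_t vocabSize, int32_t reducedVocabSize)
{
    assert(reducedVocabSize >= 0);
    return reducedVocabSize > 0 ? reducedVocabSize : vocabSize;
}
```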
-
bool useTrtNativeOps = {false}
Use TRT native ops instead of the custom plugin.
-
bool isSpecDecodeBase = {false}
Whether the base engine exposes speculative decoding verification bindings.
-
SpecDecodeMode specDecodeType{SpecDecodeMode::kNONE}
Speculative decoding strategy mode (parsed from model_type).
-
nvinfer1::DataType kvCacheDtype = {nvinfer1::DataType::kHALF}
KV cache data type. Parsed from the required top-level kv_cache_dtype field in config.json (written by llm_export.py). Accepted values: “fp16” → kHALF, “fp8” → kFP8, “int8” → kINT8, “bf16” → kBF16. The runtime validates this against the engine’s actual KV binding dtype.
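A sketch of the kv_cache_dtype string mapping. KvDtypeSketch stands in for nvinfer1::DataType so the example compiles without TensorRT headers. Throwing std::runtime_error for an unrecognized string mirrors the parse functions' documented failure mode, but the exact behavior for bad values is an assumption.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Stand-in for nvinfer1::DataType (illustration only).
enum class KvDtypeSketch { kHALF, kFP8, kINT8, kBF16 };

// Maps the kv_cache_dtype string from config.json to a data type and rejects
// anything outside the accepted set.
KvDtypeSketch parseKvCacheDtype(std::string const& s)
{
    static std::unordered_map<std::string, KvDtypeSketch> const kMap = {
        {"fp16", KvDtypeSketch::kHALF},
        {"fp8", KvDtypeSketch::kFP8},
        {"int8", KvDtypeSketch::kINT8},
        {"bf16", KvDtypeSketch::kBF16},
    };
    auto const it = kMap.find(s);
    if (it == kMap.end())
    {
        throw std::runtime_error("unsupported kv_cache_dtype: \"" + s + "\"");
    }
    return it->second;
}
```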
-
nvinfer1::DataType recurrentStateDtype = {nvinfer1::DataType::kHALF}
Recurrent state data type (hybrid models only). Parsed from the required top-level recurrent_state_dtype field when numLinearAttnLayers > 0; left at the default otherwise. The runtime validates it against the engine’s recurrent-state binding dtype.
-
nvinfer1::DataType convStateDtype = {nvinfer1::DataType::kHALF}
Conv state data type (hybrid models only). Parsed and validated the same way as recurrentStateDtype.
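The conditional parsing of the hybrid-only dtype fields can be sketched as follows. resolveHybridDtypeField is hypothetical. Dtypes are kept as strings so the sketch does not have to guess which values recurrent_state_dtype accepts.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical helper: resolve a hybrid-only dtype field such as recurrent_state_dtype.
// fieldValue is std::nullopt when the key is absent from config.json.
std::string resolveHybridDtypeField(int32_t numLinearAttnLayers, std::optional<std::string> const& fieldValue,
                                    std::string const& fieldName, std::string const& defaultValue)
{
    if (numLinearAttnLayers <= 0)
    {
        return defaultValue; // not a hybrid model: the field is not consulted
    }
    if (!fieldValue)
    {
        throw std::runtime_error("missing required field: " + fieldName);
    }
    return *fieldValue;
}
```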
-
RopeConfig ropeConfig = {}
Full RoPE configuration.
-
bool useContextDependentRope = {false}
Use context-dependent RoPE.
-
int32_t numDeepstackFeatures = {0}
Number of deepstack features (Qwen3-VL / Qwen3-Omni).
-
int32_t maxSupportedLoraRank = {0}
Maximum LoRA rank (0 = no LoRA).
-
int32_t imageTokenId = {-1}
Special token ID for images (-1 = unused).
-
int32_t audioTokenId = {-1}
Special token ID for audio (-1 = unused).
-
int32_t numLinearAttnLayers = {0}
Number of linear attention / recurrent layers.
-
int32_t recurrentStateNumHeads = {0}
Number of recurrent state heads (hv for GDN, mamba_num_heads for Mamba).
-
int32_t recurrentStateHeadDim = {0}
Recurrent state head dimension.
-
int32_t recurrentStateSize = {0}
Recurrent state dimension (v for GDN, dstate for Mamba).
-
int32_t convDim = {0}
Conv1d channel dimension.
-
int32_t convKernel = {0}
Conv1d kernel width.
-
int32_t maxVerifyTreeSize = {0}
Maximum seq_len the base engine accepts for proposal verification. Parsed from builder_config.max_verify_tree_size when isSpecDecodeBase == true; 0 otherwise. Consumers prefer the consolidated DeploymentConfig::specDecode when the deployment view is available.
-
int32_t maxDraftTreeSize = {0}
Maximum seq_len the draft engine accepts for proposal / draft generation. Parsed from builder_config.max_draft_tree_size by parseDraftEngineConfig; 0 on base / vanilla engines. Consumers prefer the consolidated DeploymentConfig::specDecode when the deployment view is available.
-
int32_t baseModelHiddenSize = {0}
Hidden dimension the draft engine expects for its hidden_states_input binding (equal to the base engine’s hidden-state output dimension as seen by the draft). Parsed from the top-level base_model_hidden_size field by parseDraftEngineConfig; left at 0 on base / vanilla engines. Differs from base.hiddenSize for EAGLE-3 (base.hiddenSize * 3, a multi-layer concatenation) and equals base.hiddenSize for MTP. The deployment factory copies this into DeploymentConfig::specDecode->baseOutputHiddenDim.
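The baseModelHiddenSize rule can be captured as a small consistency check. DraftKindSketch is a hypothetical enum, because the actual SpecDecodeMode enumerators are not listed on this page. expectedBaseModelHiddenSize is an illustrative name.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical draft-kind enum (not the runtime's SpecDecodeMode).
enum class DraftKindSketch { kEagle3, kMTP };

// Expected width of the draft engine's hidden_states_input:
// EAGLE-3 concatenates hidden states from three base layers; MTP uses one.
int32_t expectedBaseModelHiddenSize(DraftKindSketch kind, int32_t baseHiddenSize)
{
    switch (kind)
    {
    case DraftKindSketch::kEagle3: return baseHiddenSize * 3;
    case DraftKindSketch::kMTP: return baseHiddenSize;
    }
    throw std::invalid_argument("unknown draft kind");
}
```

A loader could compare this against the parsed baseModelHiddenSize to catch a draft engine exported against the wrong base model.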
-
std::vector<HybridCacheManager::LayerType> layerTypes = {}
Maps each absolute decoder-layer index to its layer type (attention or mamba). Populated either from the canonical layer_types field in config.json or, for backward compatibility, by broadcasting the scalar numAttentionLayers / numLinearAttnLayers counts. Its size equals the number of stateful decoder layers (mlp/moe layers are excluded on the Python side).
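Because layerTypes may come either from layer_types or from the broadcast scalar counts, a check that the two views agree is a natural companion. LayerTypeSketch stands in for HybridCacheManager::LayerType, and layerTypesConsistent is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for HybridCacheManager::LayerType.
enum class LayerTypeSketch { kAttention, kMamba };

// Hypothetical consistency check: per-type counts in layerTypes agree with the
// scalar numAttentionLayers / numLinearAttnLayers fields.
bool layerTypesConsistent(std::vector<LayerTypeSketch> const& layerTypes, int32_t numAttentionLayers,
                          int32_t numLinearAttnLayers)
{
    std::ptrdiff_t const numAttn = std::count(layerTypes.begin(), layerTypes.end(), LayerTypeSketch::kAttention);
    std::ptrdiff_t const numLinear = static_cast<std::ptrdiff_t>(layerTypes.size()) - numAttn;
    return numAttn == numAttentionLayers && numLinear == numLinearAttnLayers;
}
```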
-
std::vector<KVLayerConfig> kvLayerConfigs = {}
Per-attention-layer KV config. Its size equals the number of attention entries in layerTypes. Indexed by LOCAL attention-layer index (0..numAttn-1), NOT by absolute decoder-layer index.
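kvLayerConfigs is indexed by local attention-layer index, so code that walks decoder layers needs an absolute-to-local translation. The sketch below builds one. It assumes local indices are assigned to attention layers in increasing absolute order. LayerTypeSketch and absoluteToLocalAttentionIndex are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for HybridCacheManager::LayerType.
enum class LayerTypeSketch { kAttention, kMamba };

// Maps absolute decoder-layer index -> local attention-layer index (the index into
// kvLayerConfigs), or -1 for layers that have no KV cache.
std::vector<int32_t> absoluteToLocalAttentionIndex(std::vector<LayerTypeSketch> const& layerTypes)
{
    std::vector<int32_t> table(layerTypes.size(), -1);
    int32_t nextLocal = 0;
    for (std::size_t i = 0; i < layerTypes.size(); ++i)
    {
        if (layerTypes[i] == LayerTypeSketch::kAttention)
        {
            table[i] = nextLocal++;
        }
    }
    return table;
}
```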
- LLMEngineConfig trt_edgellm::rt::parseEngineConfig(std::filesystem::path const &configPath)
Parse a config.json file (the same format used by the existing runtime) into an LLMEngineConfig.
- Parameters:
configPath – Path to config.json.
- Throws:
std::runtime_error – if the file cannot be opened or parsed, or if required fields are missing.
- Returns:
Parsed configuration.
- LLMEngineConfig trt_edgellm::rt::parseDraftEngineConfig(std::filesystem::path const &configPath)
Parse a SpecDecode draft engine’s config.json into an LLMEngineConfig.
The draft config carries a reduced field set (no builder_config.eagle_base, and its own draft_vocab_size). max_draft_tree_size is required and is parsed into cfg.maxDraftTreeSize; cfg.maxVerifyTreeSize stays at 0 on the draft side. isSpecDecodeBase is left false because this is the draft engine, not the base engine. Cross-engine fields (draftHiddenSize, baseOutputHiddenDim) are not stored on LLMEngineConfig; they are derived in createDeploymentConfig and live on DeploymentConfig::specConfig.
- Parameters:
configPath – Path to the draft engine’s config.json.
- Throws:
std::runtime_error – on parse failure or missing required fields.
- Returns:
Parsed configuration.
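The draft-side post-conditions described above can be written as a check. TreeFieldsSketch mirrors only the three relevant LLMEngineConfig members, and checkDraftInvariants is hypothetical. Treating a non-positive maxDraftTreeSize as invalid is an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Mirror of the LLMEngineConfig members that distinguish a draft parse from a base parse.
struct TreeFieldsSketch
{
    bool isSpecDecodeBase = false;
    int32_t maxVerifyTreeSize = 0;
    int32_t maxDraftTreeSize = 0;
};

// Hypothetical check of the documented draft-side post-conditions.
void checkDraftInvariants(TreeFieldsSketch const& cfg)
{
    if (cfg.isSpecDecodeBase)
    {
        throw std::runtime_error("draft config must not set isSpecDecodeBase");
    }
    if (cfg.maxVerifyTreeSize != 0)
    {
        throw std::runtime_error("maxVerifyTreeSize must stay 0 on the draft side");
    }
    if (cfg.maxDraftTreeSize <= 0)
    {
        throw std::runtime_error("max_draft_tree_size is required for the draft engine");
    }
}
```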
- std::string trt_edgellm::rt::formatEngineConfig(LLMEngineConfig const &config)
Format the config as a human-readable string (for logging).
- void trt_edgellm::rt::validateAgainstEngine(LLMEngineConfig const &config, EngineExecutor const &executor, char const *engineLabel)
Cross-check an engine’s KV / recurrent / conv binding dtypes against their parsed-config counterparts. The parsed config is the source of truth; any mismatch raises a std::runtime_error with an actionable message.
- Parameters:
config – Parsed LLMEngineConfig (from parseEngineConfig or similar).
executor – Engine whose binding dtypes should match.
engineLabel – Short label (“base” / “draft”) used in error messages.
- Throws:
std::runtime_error – on any KV / recurrent / conv dtype mismatch.
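A sketch of the kind of comparison validateAgainstEngine performs. The real function reads binding dtypes from an EngineExecutor. Here DtypeSketch, StateDtypesSketch, checkDtype, validateSketch, and the isHybrid flag are all stand-ins. Skipping the recurrent/conv checks for non-hybrid engines is an assumption.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Stand-in for nvinfer1::DataType (illustration only).
enum class DtypeSketch { kHALF, kBF16, kFP8, kINT8 };

char const* dtypeName(DtypeSketch d)
{
    switch (d)
    {
    case DtypeSketch::kHALF: return "kHALF";
    case DtypeSketch::kBF16: return "kBF16";
    case DtypeSketch::kFP8: return "kFP8";
    case DtypeSketch::kINT8: return "kINT8";
    }
    return "unknown";
}

// One side of the comparison: parsed config or engine bindings.
struct StateDtypesSketch
{
    DtypeSketch kv;
    DtypeSketch recurrent;
    DtypeSketch conv;
};

// The config is the source of truth; a mismatch produces an actionable message.
void checkDtype(char const* engineLabel, char const* what, DtypeSketch fromConfig, DtypeSketch fromEngine)
{
    if (fromConfig != fromEngine)
    {
        throw std::runtime_error(std::string(engineLabel) + " engine: " + what + " binding dtype is "
                                 + dtypeName(fromEngine) + " but config.json declares " + dtypeName(fromConfig)
                                 + "; re-export the engine or fix config.json");
    }
}

void validateSketch(StateDtypesSketch const& config, StateDtypesSketch const& engine, bool isHybrid,
                    char const* engineLabel)
{
    checkDtype(engineLabel, "KV cache", config.kv, engine.kv);
    if (isHybrid) // assumption: non-hybrid engines have no recurrent/conv bindings
    {
        checkDtype(engineLabel, "recurrent state", config.recurrent, engine.recurrent);
        checkDtype(engineLabel, "conv state", config.conv, engine.conv);
    }
}
```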