LLM Inference Spec Decode Runtime#

class LLMInferenceSpecDecodeRuntime#

Unified LLM inference runtime with optional Eagle speculative decoding.

Manages the inference pipeline for both standard (vanilla) and Eagle speculative decoding modes. When constructed without a drafting config, it operates as a pure vanilla decoding runtime (equivalent to the former LLMInferenceRuntime) with zero draft-model memory overhead. It coordinates the base model, an optional draft model, and multimodal processing (vision + audio).

Public Functions

LLMInferenceSpecDecodeRuntime(
std::string const &engineDir,
std::string const &multimodalEngineDir,
std::unordered_map<std::string, std::string> const &loraWeightsMap,
EagleDraftingConfig const &draftingConfig,
cudaStream_t stream
)#

Construct runtime with Eagle speculative decoding.

Parameters:
  • engineDir – Directory containing engine files

  • multimodalEngineDir – Directory containing multimodal engine files

  • loraWeightsMap – Map of LoRA weight names to file paths

  • draftingConfig – Eagle drafting configuration

  • stream – CUDA stream for operations

Throws:

std::runtime_error – if directories do not contain expected data, or runner initialization fails

LLMInferenceSpecDecodeRuntime(
std::string const &engineDir,
std::string const &multimodalEngineDir,
std::unordered_map<std::string, std::string> const &loraWeightsMap,
cudaStream_t stream
)#

Construct runtime for vanilla-only decoding (no draft model).

Parameters:
  • engineDir – Directory containing engine files

  • multimodalEngineDir – Directory containing multimodal engine files

  • loraWeightsMap – Map of LoRA weight names to file paths

  • stream – CUDA stream for operations

Throws:

std::runtime_error – if directories do not contain expected data, or runner initialization fails

~LLMInferenceSpecDecodeRuntime() noexcept = default#

Destructor.

bool captureDecodingCUDAGraph(cudaStream_t stream)#

Capture CUDA graphs for decoding stages to optimize performance.

When a draft model is present, captures graphs for draft proposal, draft accept-token, base verification, and base vanilla decoding. Without a draft model, captures only vanilla decoding graphs.

Note

If capture fails for any stage, inference can still proceed without CUDA graphs, at the cost of degraded performance.

Parameters:

stream – CUDA stream

Throws:

std::runtime_error – if a tensor reshape operation fails

Returns:

True if all stage captures succeed, false otherwise

bool handleRequest(
LLMGenerationRequest const &request,
LLMGenerationResponse &response,
cudaStream_t stream
)#

Handle generation request.

Parameters:
  • request – Generation request with prompts and parameters

  • response – Output response with generated tokens and text

  • stream – CUDA stream

Throws:

std::runtime_error – if an LLM or CUDA operation fails

Returns:

True on success, false on failure

bool genAndSaveSystemPromptKVCache(
std::string const &prompt,
std::string const &loraWeightsName,
cudaStream_t stream
)#

Generate and save the system prompt KV cache (public API matching the standard runtime signature).

Parameters:
  • prompt – The system prompt for which to generate the KV cache

  • loraWeightsName – The name of the LoRA weights

  • stream – The CUDA stream used for the generation

Throws:

std::runtime_error – if a CUDA operation fails

Returns:

True if the KV cache is generated and saved successfully, false otherwise

inline metrics::LLMPrefillMetrics const &getPrefillMetrics(
) const noexcept#

Get LLM prefill stage metrics.

inline metrics::EagleGenerationMetrics const &getEagleGenerationMetrics(
) const noexcept#

Get Eagle generation stage metrics (only meaningful when a draft model is present).

inline metrics::LLMGenerationMetrics const &getGenerationMetrics(
) const noexcept#

Get vanilla generation stage metrics (only meaningful on the vanilla path, when no draft model is loaded).

inline metrics::MultimodalMetrics getMultimodalMetrics(
) const noexcept#

Get multimodal metrics (returns empty metrics if no multimodal runner).

inline bool hasDraftModel() const noexcept#

Check if draft model is loaded and spec-decode is available.

struct SystemPromptKVCache#

Structure holding a cached system prompt and its KV cache (unified with recurrent-state support).

Public Members

std::string systemPrompt#

The system prompt text.

std::vector<tokenizer::Rank> tokenizedPrompt#

Tokenized version of the system prompt.

std::vector<rt::Tensor> kvCacheLayers#

Per-layer KV cache tensors for the system prompt.

std::vector<rt::Tensor> recurrentStateContents#

Cached recurrent states for hybrid layers (empty if not applicable).

std::vector<rt::Tensor> convStateContents#

Cached conv states for hybrid layers (empty if not applicable).

struct BatchResult#

Batch result data for a single sequence.

Encapsulates all data needed to track a batch’s execution results, whether it’s active or evicted. Groups related fields together for better cache locality and maintainability.

Public Members

std::vector<int32_t> tokenIds#

Generated token IDs.

std::vector<int32_t> rawBatchedInputIds#

Original input token IDs.

int32_t generateLength = {0}#

Number of tokens generated.

int32_t actualIterations = {0}#

Number of iterations executed.

int32_t effectivePrefillLength = {0}#

Effective prefill length (excluding reused KV cache length).

struct SpecDecodeInferenceContext#

Execution context for speculative decode runtime.

Holds execution information and intermediate metadata during inference. Supports multi-batch inference with independent sequence tracking.

Public Functions

void initialize(
int32_t batchSize,
int32_t maxGenLength,
rt::OptionalInputTensor const &visual,
rt::OptionalInputTensors const &deepstackFeatures,
std::string const &loraName,
cudaStream_t cudaStream
)#

Initialize the context with given parameters.

Parameters:
  • batchSize – Active batch size

  • maxGenLength – Maximum generation length

  • visual – Optional visual embeddings

  • deepstackFeatures – Deepstack features for Qwen3-VL (raw features before embedding)

  • loraName – LoRA weights name used by this request

  • cudaStream – CUDA stream for operations

Public Members

std::vector<std::string> systemPrompts#

System prompts for each sequence in batch.

std::vector<std::vector<int32_t>> rawBatchedInputIds#

Original token IDs before preprocessing (i.e., before padding is added and reused system-prompt IDs are removed).

std::vector<std::vector<int32_t>> tokenIds#

Token IDs for each sequence: [batch_size][seq_length].

std::vector<int32_t> currentGenerateLengths#

Current generation length for each sequence: [batch_size].

std::vector<int32_t> effectivePrefillLengths#

Effective prefill length (excluding reused KVCache length) [batch_size].

std::vector<int8_t> finishedStates#

Finished state for each sequence: [batch_size] (0 = not finished, 1 = finished).

std::unordered_map<int32_t, BatchResult> completedBatches#

Results of completed batches (unified storage).

std::vector<int32_t> batchIndexMapping#

Maps current batch index to original index.

std::vector<SlotStreamState> slotStreams#

Per-slot streaming state (parallel to tokenIds).

rt::OptionalInputTensor visualEmbeddings#

Optional visual embeddings.

rt::OptionalInputTensor audioEmbeddings#

Optional audio embeddings.

rt::OptionalInputTensors deepstackFeatures#

Deepstack features for Qwen3-VL (raw features before embedding).

int32_t generationRound#

Current generation round (shared across all batches).

int32_t maxGenerateLength#

Maximum generation length.

int32_t activeBatchSize#

Current active batch size.

std::string loraWeightsName = {""}#

LoRA adapter name used by this request.

cudaStream_t stream#

CUDA stream.

float temperature = {1.0f}#

Temperature for sampling.

float topP = {1.0f}#

Top-P (nucleus) sampling parameter.

int64_t topK = {0}#

Top-K sampling parameter.

struct EagleDraftingConfig#

Drafting configuration for Eagle speculative decoding.

Configuration parameters to drive Eagle spec-decoding.

Public Members

int32_t draftingTopK#

Number of tokens to select from each predecessor node for the next draft-tree level.

int32_t draftingStep#

Number of drafting steps with draft model.

int32_t verifyTreeSize#

Number of draft-tree tokens submitted to the base model for verification.