/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "llmInferenceSpecDecodeRuntime.h"
#include "common/bindingNames.h"
#include "common/checkMacros.h"
#include "common/cudaUtils.h"
#include "common/logger.h"
#include "common/mathUtils.h"
#include "common/safetensorsUtils.h"
#include "kernels/embeddingKernels/embeddingKernels.h"
#include "kernels/posEncoding/initializeCosSinCache.h"
#include "kernels/speculative/batchEvictKernels.h"
#include "kernels/speculative/eagleAcceptKernels.h"
#include "kernels/speculative/eagleUtilKernels.h"
#include "multimodal/multimodalRunner.h"
#include "profiling/nvtx_wrapper.h"
#include "profiling/timer.h"
#include "runtime/hybridCacheManager.h"
#include "runtime/llmRuntimeUtils.h"
#include "sampler/sampling.h"
#include <algorithm>
#include <cmath>
#include <filesystem>
#include <fstream>
#include <functional>
#include <optional>
#include <string>
#include <vector>

using namespace nvinfer1;

namespace trt_edgellm
{
namespace
{
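// Builds a composite lookup key from the system prompt and the LoRA weights name.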
std::tuple<std::string, std::string> keySystemPromptWithLoraWeights(
    std::string const& systemPrompt, std::string const& loraWeightsName)
{
    return std::make_tuple(systemPrompt, loraWeightsName);
}

} // namespace

namespace rt
{

void SpecDecodeInferenceContext::initialize(int32_t _activeBatchSize, int32_t _maxGenerateLength,
    rt::OptionalInputTensor const& _visualEmbeddings, rt::OptionalInputTensors const& _deepstackFeatures,
    std::string const& _loraWeightsName, cudaStream_t _stream)
{
    systemPrompts.resize(_activeBatchSize);
    rawBatchedInputIds.reserve(_activeBatchSize);
    tokenIds.resize(_activeBatchSize);
    currentGenerateLengths.resize(_activeBatchSize, 0);
    effectivePrefillLengths.resize(_activeBatchSize, 0);
    finishedStates.resize(_activeBatchSize, 0);
    slotStreams.clear();
    slotStreams.resize(_activeBatchSize);

    // Initialize batch index mapping (identity mapping initially)
    batchIndexMapping.resize(_activeBatchSize);
    for (int32_t i = 0; i < _activeBatchSize; ++i)
    {
        batchIndexMapping[i] = i;
    }

    // Clear completed batch storage
    completedBatches.clear();

    visualEmbeddings = _visualEmbeddings;
    deepstackFeatures = _deepstackFeatures;
    generationRound = 0;
    maxGenerateLength = _maxGenerateLength;
    activeBatchSize = _activeBatchSize;
    loraWeightsName = _loraWeightsName;
    stream = _stream;
}

LLMInferenceSpecDecodeRuntime::LLMInferenceSpecDecodeRuntime(std::string const& engineDir,
    std::string const& multimodalEngineDir, std::unordered_map<std::string, std::string> const& loraWeightsMap,
    EagleDraftingConfig const& draftingConfig, cudaStream_t stream)
{
    initializeCommon(engineDir, multimodalEngineDir, loraWeightsMap, draftingConfig, stream);
}

LLMInferenceSpecDecodeRuntime::LLMInferenceSpecDecodeRuntime(std::string const& engineDir,
    std::string const& multimodalEngineDir, std::unordered_map<std::string, std::string> const& loraWeightsMap,
    cudaStream_t stream)
{
    initializeCommon(engineDir, multimodalEngineDir, loraWeightsMap, std::nullopt, stream);
}

void LLMInferenceSpecDecodeRuntime::initializeCommon(std::string const& engineDir,
    std::string const& multimodalEngineDir, std::unordered_map<std::string, std::string> const& loraWeightsMap,
    std::optional<EagleDraftingConfig> const& draftingConfig, cudaStream_t stream)
{
    mDraftingConfig = draftingConfig;

    // Load shared embedding table from embedding.safetensors (shared between base and draft models)
    std::filesystem::path const embeddingPath = std::filesystem::path(engineDir) / "embedding.safetensors";
    mEmbedding = loadEmbeddingTable(embeddingPath, stream);

    // Helper: load an LLMEngineRunner with uniform error handling
    auto loadBaseEngine = [&](std::filesystem::path const& enginePath, std::filesystem::path const& configPath,
                              char const* label) -> std::unique_ptr<LLMEngineRunner> {
        try
        {
            auto runner = std::make_unique<LLMEngineRunner>(enginePath, configPath, loraWeightsMap, stream);
            LOG_INFO("LLMEngineRunner successfully loaded and initialized %s.", label);
            return runner;
        }
        catch (std::exception const& e)
        {
            LOG_ERROR("Failed to initialize LLMEngineRunner (%s): %s", label, e.what());
            throw std::runtime_error("Failed to initialize LLMEngineRunner: " + std::string(e.what()));
        }
    };

    if (draftingConfig.has_value())
    {
        // Eagle speculative decoding mode: use named engine files
        mBaseEngineRunner = loadBaseEngine(std::filesystem::path(engineDir) / "eagle_base.engine",
            std::filesystem::path(engineDir) / "base_config.json", "eagle base engine");
        mBaseEngineConfig = mBaseEngineRunner->getEngineConfig();

        std::filesystem::path const draftEnginePath = std::filesystem::path(engineDir) / "eagle_draft.engine";
        std::filesystem::path const draftConfigPath = std::filesystem::path(engineDir) / "draft_config.json";
        try
        {
            mDraftEngineRunner = std::make_unique<EagleDraftEngineRunner>(draftEnginePath, draftConfigPath, stream);
        }
        catch (std::exception const& e)
        {
            LOG_ERROR("Failed to initialize EagleDraftEngineRunner: %s", e.what());
            throw std::runtime_error("Failed to initialize EagleDraftEngineRunner: " + std::string(e.what()));
        }
        LOG_INFO("EagleDraftEngineRunner successfully initialized.");
        mDraftEngineConfig = mDraftEngineRunner->getDraftEngineConfig();

        // Set the runtime batch size to the minimum of the base and draft engines' maxSupportedBatchSize.
        // This ensures CUDA graph capture and tensor allocation work for both engines.
        mMaxRuntimeBatchSize
            = std::min(mBaseEngineConfig.maxSupportedBatchSize, mDraftEngineConfig->maxSupportedBatchSize);
        LOG_INFO("Runtime batch size set to: %d (base engine max: %d, draft engine max: %d)", mMaxRuntimeBatchSize,
            mBaseEngineConfig.maxSupportedBatchSize, mDraftEngineConfig->maxSupportedBatchSize);

        // Validate drafting configuration against engine capabilities
        // maxDraftTreeSize controls the maximum input length to draft proposal step
        int32_t const requiredDraftInputSize = mDraftingConfig->draftingStep * mDraftingConfig->draftingTopK;
        if (requiredDraftInputSize > mDraftEngineConfig->maxDraftTreeSize)
        {
            LOG_ERROR(
                "Drafting config requires %d draft input tokens (draftingStep=%d * draftingTopK=%d) but the engine "
                "supports a max draft tree size of %d",
                requiredDraftInputSize, mDraftingConfig->draftingStep, mDraftingConfig->draftingTopK,
                mDraftEngineConfig->maxDraftTreeSize);
            throw std::runtime_error("Drafting configuration exceeds engine draft tree size capability");
        }

        // Validate that verifyTreeSize doesn't exceed the maximum draft tree size
        if (mDraftingConfig->verifyTreeSize > mDraftEngineConfig->maxDraftTreeSize)
        {
            LOG_ERROR("Drafting config verifyTreeSize (%d) exceeds engine maxDraftTreeSize (%d)",
                mDraftingConfig->verifyTreeSize, mDraftEngineConfig->maxDraftTreeSize);
            throw std::runtime_error("Verify tree size exceeds engine maximum draft tree size");
        }
    }
    else
    {
        // Vanilla-only mode: use the hardcoded engine filename matching the old LLMInferenceRuntime
        mBaseEngineRunner = loadBaseEngine(std::filesystem::path(engineDir) / "llm.engine",
            std::filesystem::path(engineDir) / "config.json", "base engine (vanilla mode)");
        mBaseEngineConfig = mBaseEngineRunner->getEngineConfig();

        // No draft engine in vanilla mode
        mDraftEngineRunner = nullptr;
        mDraftEngineConfig = std::nullopt;

        mMaxRuntimeBatchSize = mBaseEngineConfig.maxSupportedBatchSize;
        LOG_INFO("Runtime batch size set to: %d (vanilla mode, base engine max)", mMaxRuntimeBatchSize);
    }

    if (mBaseEngineConfig.numDeepstackFeatures > 0 && multimodalEngineDir.empty())
    {
        throw std::runtime_error("--multimodalEngineDir is required for VLM engine.");
    }

    // Allocate runtime tensors up to the maximum supported sizes.
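    // In vanilla (non-spec-decode) mode the draft tree degenerates to a single token per step,
    // so all tree-related sizes below collapse to 1.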
    bool const hasDraft = (mDraftEngineRunner != nullptr);
    int32_t const effectiveMaxTreeSize
        = hasDraft ? std::max(mDraftEngineConfig->maxDraftTreeSize, mDraftingConfig->verifyTreeSize) : 1;
    int32_t const effectiveDraftTopK = hasDraft ? mDraftingConfig->draftingTopK : 1;
    int32_t const effectiveMaxAcceptDepth = hasDraft ? mDraftingConfig->draftingStep + 1 : 1;

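    // Sampling buffers must cover both base-model verification (batch * tree size entries) and
    // draft tree expansion (batch * topK * topK candidate scores), hence the max sizes below.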
    int32_t const maxLogitsSize = mMaxRuntimeBatchSize * effectiveMaxTreeSize;
    int32_t const maxVocabSize = hasDraft
        ? std::max(mBaseEngineConfig.outputVocabSize, mDraftEngineConfig->draftModelVocabSize)
        : mBaseEngineConfig.outputVocabSize;
    int32_t const maxSamplingSize = hasDraft ? std::max(mMaxRuntimeBatchSize * effectiveMaxTreeSize,
                                                   mMaxRuntimeBatchSize * effectiveDraftTopK * effectiveDraftTopK)
                                             : mMaxRuntimeBatchSize;

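    // Full draft-token table layout: 1 root token, topK first-level candidates, and topK * topK
    // candidates for each of the remaining (draftingStep - 1) expansion steps.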
    int32_t const draftFullTableLength = hasDraft
        ? 1 + effectiveDraftTopK + (mDraftingConfig->draftingStep - 1) * effectiveDraftTopK * effectiveDraftTopK
        : 0;

    LOG_DEBUG(
        "effectiveMaxTreeSize: %d, maxSamplingSize: %d, draftFullTableLength: %d to set up the SpecDecode inference "
        "runtime",
        effectiveMaxTreeSize, maxSamplingSize, draftFullTableLength);

    // Reserve enough workspace for sampling, accounting for the batch dimension in the draft proposal stage.
    int32_t const maxSamplingWorkspaceSize = hasDraft
        ? std::max(getSelectAllTopKWorkspaceSize(mMaxRuntimeBatchSize, mBaseEngineConfig.outputVocabSize, 1),
              getSelectAllTopKWorkspaceSize(mMaxRuntimeBatchSize * effectiveDraftTopK,
                  mDraftEngineConfig->draftModelVocabSize, effectiveDraftTopK))
        : static_cast<int32_t>(getTopKtopPSamplingWorkspaceSize(mMaxRuntimeBatchSize, mBaseEngineConfig.outputVocabSize,
              SamplingParams(mMaxRuntimeBatchSize, mBaseEngineConfig.outputVocabSize, 1.0f, 0, 0.9f)));

    try
    {
        mIdsInput = rt::Tensor({mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength}, rt::DeviceType::kGPU,
            DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mIdsInput");
        mInputsEmbeds = rt::Tensor(
            {mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength, mBaseEngineConfig.hiddenSize},
            rt::DeviceType::kGPU, DataType::kHALF, "LLMInferenceSpecDecodeRuntime::mInputsEmbeds");
        // Allocate deepstack embeddings if needed (one tensor per feature)
        if (mBaseEngineConfig.numDeepstackFeatures > 0)
        {
            mDeepstackEmbeds.resize(mBaseEngineConfig.numDeepstackFeatures);
            for (int32_t i = 0; i < mBaseEngineConfig.numDeepstackFeatures; ++i)
            {
                mDeepstackEmbeds[i] = rt::Tensor(
                    {mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength, mBaseEngineConfig.hiddenSize},
                    rt::DeviceType::kGPU, DataType::kHALF,
                    format::fmtstr("LLMInferenceSpecDecodeRuntime::mDeepstackEmbeds[%d]", i));
            }
            LOG_INFO("Allocated %d deepstack embeds tensors with shape [%d, %d, %d]",
                mBaseEngineConfig.numDeepstackFeatures, mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength,
                mBaseEngineConfig.hiddenSize);
        }
        mContextLengthsInput = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kCPU, DataType::kINT32,
            "LLMInferenceSpecDecodeRuntime::mContextLengthsInput");
        // Allocate mLogitsOutput with max capacity to support both draft (smaller vocab) and base (larger vocab)
        // operations. Max size needed: batch_size * verify_tree_size * base_vocab_size for base verification.
        mLogitsOutput = rt::Tensor({maxLogitsSize, maxVocabSize}, rt::DeviceType::kGPU, DataType::kFLOAT,
            "LLMInferenceSpecDecodeRuntime::mLogitsOutput");

        // Draft-exclusive tensors (only allocated when draft model present)
        if (hasDraft)
        {
            // Base hidden states are consumed by the draft model — only needed for spec-decode.
            // outputHiddenDim is only set when enableEagleSpecDecode is true in the engine config.
            mBaseHiddenStatesOutput = rt::Tensor(
                {mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength, mBaseEngineConfig.outputHiddenDim},
                rt::DeviceType::kGPU, DataType::kHALF, "LLMInferenceSpecDecodeRuntime::mBaseHiddenStatesOutput");
            mDraftTreeSize = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kGPU, DataType::kINT32,
                "LLMInferenceSpecDecodeRuntime::mDraftTreeSize");
            mDraftTreeMask = rt::Tensor({mMaxRuntimeBatchSize, effectiveMaxTreeSize, effectiveMaxTreeSize},
                rt::DeviceType::kGPU, DataType::kINT8, "LLMInferenceSpecDecodeRuntime::mDraftTreeMask");
            mDraftHiddenStatesInput = rt::Tensor({mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength,
                                                     mDraftEngineConfig->draftModelHiddenDim},
                rt::DeviceType::kGPU, DataType::kHALF, "LLMInferenceSpecDecodeRuntime::mDraftHiddenStatesInput");
            mDraftHiddenStatesOutput
                = rt::Tensor({mMaxRuntimeBatchSize, effectiveDraftTopK, mDraftEngineConfig->draftModelHiddenDim},
                    rt::DeviceType::kGPU, DataType::kHALF, "LLMInferenceSpecDecodeRuntime::mDraftHiddenStatesOutput");
            mDraftTokenIdsFullTable = rt::Tensor({mMaxRuntimeBatchSize, draftFullTableLength}, rt::DeviceType::kGPU,
                DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mDraftTokenIdsFullTable");
            mDraftTokenScoreFullTable = rt::Tensor({mMaxRuntimeBatchSize, draftFullTableLength}, rt::DeviceType::kGPU,
                DataType::kFLOAT, "LLMInferenceSpecDecodeRuntime::mDraftTokenScoreFullTable");
            mDraftTokenPredecessorFullTable
                = rt::Tensor({mMaxRuntimeBatchSize, draftFullTableLength}, rt::DeviceType::kGPU, DataType::kINT32,
                    "LLMInferenceSpecDecodeRuntime::mDraftTokenPredecessorFullTable");
            // Draft vocab mapping table is 1D and shared across all batches (not batch-dependent)
            mDraftVocabMappingTable = rt::Tensor({mDraftEngineConfig->draftModelVocabSize}, rt::DeviceType::kGPU,
                DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mDraftVocabMappingTable");
            mDraftTreeRootTokenId = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kGPU, DataType::kINT32,
                "LLMInferenceSpecDecodeRuntime::mDraftTreeRootTokenId");
            mDraftTokenIdsTable = rt::Tensor({mMaxRuntimeBatchSize, effectiveDraftTopK * effectiveDraftTopK},
                rt::DeviceType::kGPU, DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mDraftTokenIdsTable");
            mDraftTokenScoresTable = rt::Tensor({mMaxRuntimeBatchSize, effectiveDraftTopK * effectiveDraftTopK},
                rt::DeviceType::kGPU, DataType::kFLOAT, "LLMInferenceSpecDecodeRuntime::mDraftTokenScoresTable");
            mDraftTokenIntermediateScores = rt::Tensor({mMaxRuntimeBatchSize, effectiveDraftTopK}, rt::DeviceType::kGPU,
                DataType::kFLOAT, "LLMInferenceSpecDecodeRuntime::mDraftTokenIntermediateScores");
            mDraftTokenIntermediateParents
                = rt::Tensor({mMaxRuntimeBatchSize, effectiveDraftTopK}, rt::DeviceType::kGPU, DataType::kINT32,
                    "LLMInferenceSpecDecodeRuntime::mDraftTokenIntermediateParents");
            mAcceptedTokenIds = rt::Tensor({mMaxRuntimeBatchSize, effectiveMaxAcceptDepth}, rt::DeviceType::kGPU,
                DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mAcceptedTokenIds");
            mAcceptedTokenIndices = rt::Tensor({mMaxRuntimeBatchSize, effectiveMaxAcceptDepth}, rt::DeviceType::kGPU,
                DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mAcceptedTokenIndices");
            mAcceptLength = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kGPU, DataType::kINT32,
                "LLMInferenceSpecDecodeRuntime::mAcceptLength");
        }

        mSamplingWorkspace = rt::Tensor({maxSamplingWorkspaceSize}, rt::DeviceType::kGPU, DataType::kINT8,
            "LLMInferenceSpecDecodeRuntime::mSamplingWorkspace");
        mSamplingIndices = rt::Tensor({maxSamplingSize}, rt::DeviceType::kGPU, DataType::kINT32,
            "LLMInferenceSpecDecodeRuntime::mSamplingIndices");
        mSamplingScores = rt::Tensor({maxSamplingSize}, rt::DeviceType::kGPU, DataType::kFLOAT,
            "LLMInferenceSpecDecodeRuntime::mSamplingScores");

        // Allocate batch mapping tensor for batch eviction
        mDeviceBatchMapping = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kGPU, DataType::kINT32,
            "LLMInferenceSpecDecodeRuntime::mDeviceBatchMapping");

        mHostPackedTokenIds = rt::Tensor({mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength},
            rt::DeviceType::kCPU, DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mHostPackedTokenIds");
        mHostSelectedTokenIds = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kCPU, DataType::kINT32,
            "LLMInferenceSpecDecodeRuntime::mHostSelectedTokenIds");
        mHostReuseKVCacheLengths = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kCPU, DataType::kINT32,
            "LLMInferenceSpecDecodeRuntime::mHostReuseKVCacheLengths");

        // Pre-allocate multimodal indices tensor (used for audio/vision embedding lookup)
        mMultimodalIndices = rt::Tensor({mMaxRuntimeBatchSize, mBaseEngineConfig.maxSupportedInputLength},
            rt::DeviceType::kGPU, DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mMultimodalIndices");

        // Draft-exclusive host tensors
        if (hasDraft)
        {
            mHostAcceptLengths = rt::Tensor({mMaxRuntimeBatchSize}, rt::DeviceType::kCPU, DataType::kINT32,
                "LLMInferenceSpecDecodeRuntime::mHostAcceptLengths");
            mHostAcceptedTokenIds = rt::Tensor({mMaxRuntimeBatchSize, effectiveMaxAcceptDepth}, rt::DeviceType::kCPU,
                DataType::kINT32, "LLMInferenceSpecDecodeRuntime::mHostAcceptedTokenIds");
        }
    }
    catch (std::exception const& e)
    {
        LOG_ERROR("Failed to allocate runtime tensors: %s", e.what());
        throw std::runtime_error("Failed to allocate runtime tensors: " + std::string(e.what()));
    }
    LOG_INFO("Runtime tensors successfully allocated.");

    // Load conversion table from draft model vocab to base model vocab (draft-only).
    if (hasDraft)
    {
        std::vector<rt::Tensor> d2tTensors;
        if (!safetensors::loadSafetensors(std::filesystem::path(engineDir) / "d2t.safetensors", d2tTensors, stream))
        {
            LOG_ERROR("Failed to load d2t.safetensors from model directory: %s", engineDir.c_str());
            throw std::runtime_error("Failed to load d2t.safetensors from model directory: " + engineDir);
        }

        // Check we have exactly one tensor and use it
        check::check(d2tTensors.size() == 1, "d2t.safetensors should contain exactly one tensor");
        check::check(d2tTensors[0].getShape().getNumDims() == 1, "d2t tensor should be 1D");
        check::check(d2tTensors[0].getShape()[0] == mDraftEngineConfig->draftModelVocabSize,
            "d2t tensor length should match draft vocab size");
        mDraftVocabMappingTable = std::move(d2tTensors[0]);
    }

    // Optional: Load vocabulary mapping table if base model uses reduced vocabulary
    if (mBaseEngineConfig.reducedVocabSize > 0)
    {
        LOG_INFO("Loading vocabulary mapping table for base model reduced vocab size: %d -> %d",
            mBaseEngineConfig.reducedVocabSize, mBaseEngineConfig.vocabSize);
        std::filesystem::path const vocabMapPath = std::filesystem::path(engineDir) / binding_names::kVocabMapFileName;

        std::vector<rt::Tensor> vocabMapTensors;
        if (!safetensors::loadSafetensors(vocabMapPath, vocabMapTensors, stream))
        {
            LOG_ERROR(
                "Failed to load %s from model directory: %s", binding_names::kVocabMapFileName, engineDir.c_str());
            throw std::runtime_error("Failed to load " + std::string(binding_names::kVocabMapFileName)
                + " from model directory: " + engineDir);
        }

        // Check we have exactly one tensor and use it
        check::check(vocabMapTensors.size() == 1,
            std::string(binding_names::kVocabMapFileName) + " should contain exactly one tensor");
        check::check(vocabMapTensors[0].getShape().getNumDims() == 1, "vocab_map tensor should be 1D");
        check::check(vocabMapTensors[0].getShape()[0] == mBaseEngineConfig.reducedVocabSize,
            "vocab_map tensor length should match base model reduced vocab size");
        mBaseVocabMappingTable = std::move(vocabMapTensors[0]);
        LOG_INFO("Base model vocabulary mapping table successfully loaded.");
    }

    mTokenizer = std::make_unique<tokenizer::Tokenizer>();
    LOG_INFO("Start loading tokenizer from model directory: %s", engineDir.c_str());
    if (!mTokenizer->loadFromHF(engineDir))
    {
        LOG_ERROR("Failed to load tokenizer from model directory: %s", engineDir.c_str());
        throw std::runtime_error("Failed to load tokenizer from model directory: " + engineDir);
    }
    LOG_INFO("Tokenizer successfully loaded from model directory: %s", engineDir.c_str());

    // Optional: Setup multimodal engine runners
    if (!multimodalEngineDir.empty())
    {
        // Helper lambda to try loading a runner from a directory
        auto tryLoadRunner = [&](std::string const& dir, std::string const& name) -> std::unique_ptr<MultimodalRunner> {
            try
            {
                LOG_DEBUG("Attempting to load %s runner from %s", name.c_str(), dir.c_str());
                auto runner = MultimodalRunner::create(
                    dir, mBaseEngineConfig.maxSupportedBatchSize, mBaseEngineConfig.maxKVCacheCapacity, stream);
                LOG_INFO("%s runner successfully initialized", name.c_str());
                return runner;
            }
            catch (std::exception const& e)
            {
                LOG_DEBUG("Failed to load %s runner from %s: %s", name.c_str(), dir.c_str(), e.what());
                return nullptr;
            }
        };

        // Try to load audio runner from multimodalEngineDir/audio
        mAudioRunner = tryLoadRunner(multimodalEngineDir + "/audio", "Audio");

        // Try to load visual runner (with fallback to root directory for backward compatibility)
        mVisionRunner = tryLoadRunner(multimodalEngineDir + "/visual", "Visual");
        if (!mVisionRunner)
        {
            mVisionRunner = tryLoadRunner(multimodalEngineDir, "Vision");
        }

        // At least one multimodal runner must be available
        if (!mAudioRunner && !mVisionRunner)
        {
            throw std::runtime_error("No valid multimodal engine found in " + multimodalEngineDir);
        }
    }

    // Setup shared execution context memory for all engines (base, draft, and optionally VIT).
    // All engines execute serially (not concurrently), so they can share a single buffer
    // sized to the maximum requirement among all engines.
    int64_t const baseContextMemorySize = mBaseEngineRunner->getRequiredContextMemorySize();
    int64_t const draftContextMemorySize = mDraftEngineRunner ? mDraftEngineRunner->getRequiredContextMemorySize() : 0;
    int64_t const visionContextMemorySize = mVisionRunner ? mVisionRunner->getRequiredContextMemorySize() : 0;
    int64_t const audioContextMemorySize = mAudioRunner ? mAudioRunner->getRequiredContextMemorySize() : 0;
    int64_t const sharedContextMemorySize
        = std::max({baseContextMemorySize, draftContextMemorySize, visionContextMemorySize, audioContextMemorySize});
    mSharedExecContextMemory = rt::Tensor({sharedContextMemorySize}, rt::DeviceType::kGPU, nvinfer1::DataType::kUINT8,
        "LLMInferenceSpecDecodeRuntime::mSharedExecContextMemory");
    mBaseEngineRunner->setContextMemory(mSharedExecContextMemory);
    if (mDraftEngineRunner)
    {
        mDraftEngineRunner->setContextMemory(mSharedExecContextMemory);
    }
    if (mVisionRunner)
    {
        mVisionRunner->setContextMemory(mSharedExecContextMemory);
    }
    if (mAudioRunner)
    {
        mAudioRunner->setContextMemory(mSharedExecContextMemory);
    }
    LOG_INFO(
        "Setup shared execution context memory: %zu bytes (base requires: %zu, draft requires: %zu, vision requires: "
        "%zu, audio requires: %zu)",
        static_cast<size_t>(sharedContextMemorySize), static_cast<size_t>(baseContextMemorySize),
        static_cast<size_t>(draftContextMemorySize), static_cast<size_t>(visionContextMemorySize),
        static_cast<size_t>(audioContextMemorySize));
}

bool LLMInferenceSpecDecodeRuntime::handleRequest(
    LLMGenerationRequest const& request, LLMGenerationResponse& response, cudaStream_t stream)
{
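    // Spec-decode is used only when a draft engine is loaded and the request does not opt out.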
    int32_t const activeBatchSize = static_cast<int32_t>(request.requests.size());
    bool const enableSpecDecode = (mDraftEngineRunner != nullptr) && !request.disableSpecDecode;
    std::string const& loraWeightsName = request.loraWeightsName;

    if (!validateRequestConfig(request))
    {
        return false;
    }

    if (!validateStreamingSubmission(request))
    {
        return false;
    }

    // Speculative decoding only supports greedy; override non-default sampling params.
    bool const hasNonDefaultSampling
        = (request.topK > 1 || request.topP < 1.0f || std::fabs(request.temperature - 1.0f) > 1e-3f);
    if (enableSpecDecode && hasNonDefaultSampling)
    {
        LOG_WARNING("Spec-decode active: overriding sampling params to greedy (ignoring temp/topK/topP).");
    }

    int32_t maxGenerateLength = request.maxGenerateLength;

    // Apply chat template for all requests (common for both multimodal and non-multimodal)
    request.formattedRequests.resize(activeBatchSize);
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        // Apply chat template to populate both formatted system prompt and full formatted prompt
        mTokenizer->applyChatTemplate(request.requests[i], request.formattedRequests[i], request.applyChatTemplate,
            request.addGenerationPrompt, request.enableThinking);
    }

    SpecDecodeInferenceContext context;
    context.initialize(
        activeBatchSize, maxGenerateLength, std::nullopt, rt::OptionalInputTensors{}, loraWeightsName, stream);
    bool const supportsMultimodalInput = (mAudioRunner != nullptr) || (mVisionRunner != nullptr);

    if (supportsMultimodalInput)
    {
        if (!multiModalRuntimePreprocess(request, context, stream))
        {
            return false;
        }
    }
    else
    {
        for (int32_t i = 0; i < activeBatchSize; ++i)
        {
            context.systemPrompts[i] = request.formattedRequests[i].formattedSystemPrompt;
            context.rawBatchedInputIds.emplace_back(
                mTokenizer->encode(request.formattedRequests[i].formattedCompleteRequest, false));
            if (context.rawBatchedInputIds[i].empty())
            {
                LOG_ERROR("Failed to tokenize input text for request %d in batch", i);
                return false;
            }
        }
    }

    // Forward sampling params to context; spec-decode forces greedy.
    context.temperature = enableSpecDecode ? 1.0f : request.temperature;
    context.topP = enableSpecDecode ? 1.0f : request.topP;
    context.topK = enableSpecDecode ? 0 : request.topK;

    // The spec-decode path needs extra KV reserve for draft tokens during verification.
    constexpr int32_t kDRAFT_KVCACHE_RESERVE_LENGTH{100};
    int32_t const kvCacheCapacity = enableSpecDecode
        ? std::max(mBaseEngineConfig.maxKVCacheCapacity, mDraftEngineConfig->maxKVCacheCapacity)
        : mBaseEngineConfig.maxKVCacheCapacity;
    int32_t const kvcReserve = enableSpecDecode ? kDRAFT_KVCACHE_RESERVE_LENGTH : 0;

    // In production, the system-prompt KV cache is saved during warm-up.
    // We disable profiling here to make benchmarking closer to production inference results.
    bool profilingEnabled = getProfilingEnabled();
    if (profilingEnabled)
    {
        setProfilingEnabled(false);
    }

    // Generate system prompt KVCache for each sequence in the batch
    if (request.saveSystemPromptKVCache)
    {
        for (int32_t i = 0; i < activeBatchSize; ++i)
        {
            bool const saveCacheStatus = genAndSaveSystemPromptKVCache(context, i);
            if (!saveCacheStatus)
            {
                LOG_WARNING(
                    "Failed to save system prompt KVCache for request %d in batch. "
                    "Continue to handle the request without saving the system prompt KVCache.",
                    i);
            }
        }
    }

    if (profilingEnabled)
    {
        setProfilingEnabled(true);
    }

    // Prepare to handle a new set of sequences: pack the inputIds, prepare the input/output tensors,
    // reset the KVCache state, and apply the reused prefix KVCache if available.
    if (!setUpForPrefillExecution(context))
    {
        LOG_ERROR("Prefill execution setup failed. This request cannot be handled.");
        return false;
    }

    // ── Streaming setup ──────────────────────────────────────────────────────
    // Attach first, record in slotStreams only on success — a throw from attach
    // keeps foreign channels out of the finalizer's reach. Seed sentTokenCount
    // to the prompt length so streaming emits only generated tokens.
    for (int32_t i = 0; i < context.activeBatchSize; ++i)
    {
        if (request.streamChannels.empty() || !request.streamChannels[i])
        {
            continue;
        }
        attachStreamChannel(request.streamChannels[i], context.batchIndexMapping[i]);
        auto& slot = context.slotStreams[i];
        slot.channel = request.streamChannels[i];
        slot.sentTokenCount = context.tokenIds[i].size();
        slot.lastEmittedTokenCount = slot.sentTokenCount;
    }
    StreamChannelFinalizer streamFinalizer(context, *mTokenizer);

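    // Clamp the generation length so that prefill tokens + generated tokens (+ the spec-decode draft
    // reserve) fit within the KV cache capacity.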
    int32_t const clampedMaxGenerateLength = clampMaxGenerateLengthForKVCapacity(
        context.effectivePrefillLengths, request.maxGenerateLength, kvCacheCapacity, kvcReserve);
    if (clampedMaxGenerateLength != context.maxGenerateLength)
    {
        context.maxGenerateLength = clampedMaxGenerateLength;
        LOG_WARNING("Reduce max generation length to %d", context.maxGenerateLength);
    }
    if (context.maxGenerateLength <= 0)
    {
        LOG_ERROR("Insufficient KV cache capacity for generation for this request.");
        return false;
    }

    // Prefill from the base model and run spec-decode inference.
    bool const prefillStatus = runBaseModelPrefill(context);
    if (!prefillStatus)
    {
        LOG_ERROR("Failed to execute prefill step for base model.");
        return false;
    }

    // Lambda to check if all batches are finished
    auto checkAllFinished = [&]() {
        // Check if all batches have been evicted
        if (context.activeBatchSize == 0)
        {
            return true;
        }
        for (int32_t i = 0; i < context.activeBatchSize; ++i)
        {
            if (!context.finishedStates[i])
            {
                return false;
            }
        }
        return true;
    };

    // Lambda to update finish states based on EOS and max_length. Latches
    // terminalReason atomically with the state flip — the !finishedStates guard
    // keeps first-writer-wins semantics relative to applyCancellationToFinishStates.
    auto updateFinishStates = [&]() {
        for (int32_t i = 0; i < context.activeBatchSize; ++i)
        {
            if (context.finishedStates[i])
            {
                continue; // Respect first-writer-wins (cancel may have fired).
            }
            auto& s = context.slotStreams[i];
            // Check EOS
            if (!context.tokenIds[i].empty() && context.tokenIds[i].back() == mTokenizer->getEosId())
            {
                context.finishedStates[i] = 1;
                if (s.channel)
                {
                    s.terminalReason = FinishReason::kEndId;
                }
                LOG_DEBUG("Batch %d finished, reason: EOS", i);
                continue;
            }
            // Check max length
            if (context.currentGenerateLengths[i] >= context.maxGenerateLength)
            {
                context.finishedStates[i] = 1;
                if (s.channel)
                {
                    s.terminalReason = FinishReason::kLength;
                }
                LOG_DEBUG(
                    "Batch %d finished, total tokens=%d, reason: max_length", i, context.currentGenerateLengths[i]);
                continue;
            }
        }
    };

    // Post-prefill finish-state detection: cancel FIRST (so it wins over natural
    // finish if both are observable), then natural EOS/length, then emit chunks.
    applyCancellationToFinishStates(context);
    updateFinishStates();
    emitChunks(context, *mTokenizer);

    // If everything finished during prefill, evict once so activeBatchSize reaches 0
    if (checkAllFinished() && context.activeBatchSize > 0)
    {
        bool const batchEvictStatus = performBatchEvict(context);
        if (!batchEvictStatus)
        {
            LOG_ERROR("Failed to perform batch eviction.");
            return false;
        }
    }

    while (!checkAllFinished())
    {
        // Observe any consumer cancels at the top of the iteration so they land
        // first in the per-slot terminalReason latch.
        applyCancellationToFinishStates(context);

        if (enableSpecDecode)
        {
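            // One spec-decode round: (1) sync the draft model with the latest accepted tokens
            // (prefill on the first round, accept-token afterwards), (2) propose a draft token tree,
            // (3) verify the tree with the base model.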
            if (context.generationRound == 0)
            {
                bool const draftPrefillStatus = runDraftModelPrefill(context);
                if (!draftPrefillStatus)
                {
                    LOG_ERROR("Failed to execute prefill step for draft model.");
                    return false;
                }
            }
            else
            {
                bool const draftAcceptTokenStatus = runDraftModelAcceptToken(context);
                if (!draftAcceptTokenStatus)
                {
                    LOG_ERROR("Failed to execute accept token step for draft model.");
                    return false;
                }
            }

            bool const draftTreeConstructionStatus = constructDraftTree(context);
            if (!draftTreeConstructionStatus)
            {
                LOG_ERROR("Failed to construct draft tree.");
                return false;
            }

            bool const baseModelVerificationStatus = runBaseModelVerification(context);
            if (!baseModelVerificationStatus)
            {
                LOG_ERROR("Failed to verify token draft tree with base model.");
                return false;
            }
        }
        else
        {
            bool const vanillaDecodingStatus = runVanillaDecoding(context);
            if (!vanillaDecodingStatus)
            {
                LOG_ERROR("Failed to decode tokens with vanilla decoding.");
                return false;
            }
        }

        // Check finish conditions, emit streaming chunks, and increment the generation round
        updateFinishStates();
        emitChunks(context, *mTokenizer);
        context.generationRound += 1;

        // Perform batch eviction if needed (after finish states have been updated for this round)
        bool const batchEvictStatus = performBatchEvict(context);
        if (!batchEvictStatus)
        {
            LOG_ERROR("Failed to perform batch eviction.");
            return false;
        }
    }

    if (context.activeBatchSize != 0)
    {
        LOG_ERROR("Eviction failure, there should be no active batch at the end of the inference. activeBatchSize: %d",
            context.activeBatchSize);
        return false;
    }

    // Record metrics, accumulating across all completed batches
    int32_t totalReusedTokens = 0;
    int32_t totalComputedTokens = 0;
    int32_t totalGeneratedTokens = 0;
    int32_t totalIterations = 0;

    // Accumulate from completed batches
    for (auto const& [originalIdx, batchResult] : context.completedBatches)
    {
        int32_t rawPromptLength = static_cast<int32_t>(batchResult.rawBatchedInputIds.size());
        int32_t computedLength = batchResult.effectivePrefillLength;
        totalReusedTokens += (rawPromptLength - computedLength);
        totalComputedTokens += computedLength;
        totalGeneratedTokens += batchResult.generateLength;
        totalIterations += batchResult.actualIterations;
    }

    mPrefillMetrics.recordRun(totalReusedTokens, totalComputedTokens);
    if (enableSpecDecode)
    {
        mEagleGenerationMetrics.recordRun(totalIterations, totalGeneratedTokens);
    }
    else
    {
        mGenerationMetrics.recordRun(totalGeneratedTokens);
    }

    // Save output ids and decoded texts to response.
    // Maintain original batch order using original batch indices
    response.outputIds.clear();
    response.outputTexts.clear();
    response.outputIds.resize(context.completedBatches.size());
    response.outputTexts.resize(context.completedBatches.size());

    // Add outputs from completed batches (using saved original indices)
    for (auto const& [originalIdx, batchResult] : context.completedBatches)
    {
        int32_t genLength = batchResult.generateLength;

        // Log acceptance metrics for this completed batch
        if (enableSpecDecode)
        {
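            // The first generated token comes from base-model prefill sampling, so only genLength - 1
            // tokens were produced by verification rounds.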
            int32_t const verificationTokens = genLength > 0 ? genLength - 1 : 0;
            float const acceptanceRate = batchResult.actualIterations > 0
                ? static_cast<float>(verificationTokens) / static_cast<float>(batchResult.actualIterations)
                : 0.0f;
            LOG_DEBUG(
                "Batch (completed with SpecDecode, original idx %d) - Acceptance rate: %.3f, Generated tokens: %d, "
                "Iterations: %d",
                originalIdx, acceptanceRate, genLength, batchResult.actualIterations);
        }

        // Extract generated tokens
        int32_t const totalLength = static_cast<int32_t>(batchResult.tokenIds.size());

        check::check(totalLength >= genLength, "Total length should be greater than or equal to generated length");
        response.outputIds[originalIdx] = std::vector<int32_t>(
            batchResult.tokenIds.begin() + (totalLength - genLength), batchResult.tokenIds.end());
        response.outputTexts[originalIdx] = mTokenizer->decode(response.outputIds[originalIdx], true);
    }

    return true;
}

bool LLMInferenceSpecDecodeRuntime::validateRequestConfig(LLMGenerationRequest const& request)
{
    int32_t const activeBatchSize = static_cast<int32_t>(request.requests.size());
    bool const hasAudio = std::any_of(
        request.requests.begin(), request.requests.end(), [](auto const& req) { return !req.audioBuffers.empty(); });
    bool const hasVision = std::any_of(
        request.requests.begin(), request.requests.end(), [](auto const& req) { return !req.imageBuffers.empty(); });

    if (activeBatchSize == 0)
    {
        LOG_ERROR("Empty request with no requests");
        return false;
    }

    if (activeBatchSize > mMaxRuntimeBatchSize)
    {
        LOG_ERROR(
            "Requested batch size %d exceeds maximum supported batch size %d", activeBatchSize, mMaxRuntimeBatchSize);
        return false;
    }
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        if (request.requests[i].messages.empty())
        {
            LOG_ERROR("Request %d in batch is empty: no messages provided", i);
            return false;
        }
    }
    if (hasAudio && !mAudioRunner)
    {
        LOG_ERROR("Request contains audio input, but this runtime does not have an audio runner.");
        return false;
    }
    if (hasVision && !mVisionRunner)
    {
        LOG_ERROR("Request contains vision input, but this runtime does not have a vision runner.");
        return false;
    }

    return true;
}

bool LLMInferenceSpecDecodeRuntime::multiModalRuntimePreprocess(
    LLMGenerationRequest const& request, SpecDecodeInferenceContext& context, cudaStream_t stream)
{
    int32_t const activeBatchSize = static_cast<int32_t>(request.requests.size());
    bool const hasAudio = std::any_of(
        request.requests.begin(), request.requests.end(), [](auto const& req) { return !req.audioBuffers.empty(); });
    bool const hasVision = std::any_of(
        request.requests.begin(), request.requests.end(), [](auto const& req) { return !req.imageBuffers.empty(); });

    // Clear request-scoped multimodal state up front so previous requests cannot leak through reused runtime members.
    context.visualEmbeddings = std::nullopt;
    context.audioEmbeddings = std::nullopt;
    context.deepstackFeatures.clear();
    // Treat multimodal indices as request-scoped state. Only request paths that explicitly rebuild
    // mMultimodalIndices for the current request should observe a non-empty tensor downstream.
    check::check(mMultimodalIndices.reshape({0}), "Tensor reshape failed");

    // Mark multimodal preprocessing and inference for NVTX profiling
    NVTX_SCOPED_RANGE(nvtx_multimodal, "MULTIMODAL_PROCESSING", nvtx_colors::ORANGE);

    std::vector<std::vector<int32_t>> batchedInputIds;

    // Process audio inputs (if present)
    if (hasAudio && mAudioRunner)
    {
        LOG_INFO("Processing audio inputs");
        if (!mAudioRunner->preprocess(
                request, batchedInputIds, mTokenizer.get(), mBaseEngineRunner->getRopeCosSinCacheTensor(), stream))
        {
            LOG_ERROR("Audio preprocessing failed. This request cannot be handled.");
            return false;
        }

        if (!mAudioRunner->infer(stream))
        {
            LOG_ERROR("Audio inference failed. This request cannot be handled.");
            return false;
        }
    }

    // Process vision inputs (if present)
    if (hasVision && mVisionRunner)
    {
        LOG_INFO("Processing vision inputs");
        if (!mVisionRunner->preprocess(
                request, batchedInputIds, mTokenizer.get(), mBaseEngineRunner->getRopeCosSinCacheTensor(), stream))
        {
            LOG_ERROR("Vision preprocessing failed. This request cannot be handled.");
            return false;
        }

        if (!mVisionRunner->infer(stream))
        {
            LOG_ERROR("Vision inference failed. This request cannot be handled.");
            return false;
        }
    }

    if (!hasAudio && !hasVision)
    {
        for (int32_t i = 0; i < activeBatchSize; ++i)
        {
            batchedInputIds.push_back(mTokenizer->encode(request.formattedRequests[i].formattedCompleteRequest, false));
            if (batchedInputIds.back().empty())
            {
                LOG_ERROR("Failed to tokenize input text for request %d in batch", i);
                return false;
            }
        }
        if (mBaseEngineConfig.ropeConfig.type == RopeType::kMRope)
        {
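            // No multimodal inputs in this request: (re)initialize the MRoPE cos/sin cache with
            // text-only positions over the full KV-cache range.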
            rt::Tensor& ropeCosSinCache = mBaseEngineRunner->getRopeCosSinCacheTensor();
            check::check(ropeCosSinCache.reshape({mBaseEngineConfig.maxSupportedBatchSize,
                             mBaseEngineConfig.maxKVCacheCapacity, mBaseEngineConfig.rotaryDim}),
                "Tensor reshape failed");
            kernel::initializeTextOnlyMRopeCosSin(ropeCosSinCache.dataPointer<float>(),
                mBaseEngineConfig.ropeConfig.rotaryTheta, mBaseEngineConfig.rotaryDim,
                mBaseEngineConfig.maxKVCacheCapacity, mBaseEngineConfig.maxSupportedBatchSize, stream);
        }
    }

    // Get embeddings from independent runners — gate on request having multimodal data,
    // not just runner existence, to avoid leaking stale embeddings from previous requests.
    rt::OptionalInputTensor visionEmbeddings
        = (hasVision && mVisionRunner) ? std::optional{std::ref(mVisionRunner->getOutputEmbedding())} : std::nullopt;
    rt::OptionalInputTensor audioEmbeddings
        = (hasAudio && mAudioRunner) ? std::optional{std::ref(mAudioRunner->getOutputEmbedding())} : std::nullopt;
    rt::OptionalInputTensors deepstackFeatures
        = (hasVision && mVisionRunner) ? mVisionRunner->getDeepstackFeatures() : rt::OptionalInputTensors{};

    context.visualEmbeddings = visionEmbeddings;
    context.deepstackFeatures = deepstackFeatures;
    context.audioEmbeddings = audioEmbeddings;

    // Populate system prompts and raw input IDs from batchedInputIds
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        context.systemPrompts[i] = request.formattedRequests[i].formattedSystemPrompt;
        context.rawBatchedInputIds.push_back(batchedInputIds[i]);
    }

    return true;
}

bool LLMInferenceSpecDecodeRuntime::runBaseModelPrefill(SpecDecodeInferenceContext& context)
{
    TIME_STAGE(metrics::StageNames::kLLM_PREFILL, context.stream);
    NVTX_SCOPED_RANGE(nvtx_base_prefill,
        ("EAGLE_BASE_PREFILL[" + std::to_string(context.activeBatchSize) + "]").c_str(), nvtx_colors::BLUE);

    int32_t const activeBatchSize = context.activeBatchSize;

    // Prepare the inputs for prefill stage execution. The prefill lengths have already been checked to be
    // within the engine-supported range.
    int32_t const inputIdsLength
        = *std::max_element(context.effectivePrefillLengths.begin(), context.effectivePrefillLengths.end());

    check::check(mIdsInput.reshape({activeBatchSize, inputIdsLength}), "Tensor reshape failed");
    check::check(mContextLengthsInput.reshape({activeBatchSize}), "Tensor reshape failed");
    if (hasDraftModel())
    {
        check::check(
            mBaseHiddenStatesOutput.reshape({activeBatchSize, inputIdsLength, mBaseEngineConfig.outputHiddenDim}),
            "Tensor reshape failed");
    }
    check::check(mLogitsOutput.reshape({activeBatchSize, mBaseEngineConfig.outputVocabSize}), "Tensor reshape failed");

    // Setup the input tensors. ContextLen input is on CPU.
    int32_t* ctxLenData = mContextLengthsInput.dataPointer<int32_t>();
    int32_t* idsInputData = mIdsInput.dataPointer<int32_t>();

    // Pack all sequences into the host pinned memory first
    check::check(mHostPackedTokenIds.reshape({activeBatchSize, inputIdsLength}), "Tensor reshape failed");
    int32_t* hostPackedTokenIdsData = mHostPackedTokenIds.dataPointer<int32_t>();

    // Use actual prompt length (not padded length) for context_lengths to ensure we select the last real token
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        ctxLenData[i]
            = context.effectivePrefillLengths[i]; // Use actual effective prefill length instead of padded length
        std::copy(context.tokenIds[i].begin(), context.tokenIds[i].end(), hostPackedTokenIdsData + i * inputIdsLength);
    }

    CUDA_CHECK(cudaMemcpyAsync(idsInputData, hostPackedTokenIdsData, activeBatchSize * inputIdsLength * sizeof(int32_t),
        cudaMemcpyHostToDevice, context.stream));

    // Perform embedding lookup for base model prefill
    check::check(mInputsEmbeds.reshape({activeBatchSize, inputIdsLength, mBaseEngineConfig.hiddenSize}),
        "Tensor reshape failed");

    // FIXME: Inconsistent handling of multimodal input. We should unify embedding creation to adopt the
    // mMultimodalIndices schema as the single method.
    //
    // Use the multimodal kernel (dispatches via explicit imageTokenId /
    // audioTokenId) when audio is present, or for image-only inference on
    // audio-capable model families that keep ``<image>`` in-stream. Legacy
    // vision-only families remap to ``vocabSize + k`` and fall through to
    // the path below.
    bool const useExplicitMultiModalId = context.audioEmbeddings.has_value()
        || (context.visualEmbeddings.has_value() && mBaseEngineConfig.audioTokenId != 0);
    bool const useIncrementalVisualID = context.visualEmbeddings.has_value();
    if (useExplicitMultiModalId)
    {
        auto const inputShape = mIdsInput.getShape();
        size_t const inputSizeBytes = inputShape.volume() * sizeof(int32_t);
        rt::Tensor inputIdsCPU(inputShape, rt::DeviceType::kCPU, mIdsInput.getDataType());
        CUDA_CHECK(
            cudaMemcpy(inputIdsCPU.rawPointer(), mIdsInput.rawPointer(), inputSizeBytes, cudaMemcpyDeviceToHost));

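        // A token ID of 0 means the engine config does not define that multimodal token type.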
        std::optional<int32_t> audioTokenId
            = (mBaseEngineConfig.audioTokenId != 0) ? std::optional{mBaseEngineConfig.audioTokenId} : std::nullopt;
        std::optional<int32_t> imageTokenId
            = (mBaseEngineConfig.imageTokenId != 0) ? std::optional{mBaseEngineConfig.imageTokenId} : std::nullopt;
        rt::Tensor multimodalIndicesCPU
            = generateMultimodalIndices(inputIdsCPU, audioTokenId, imageTokenId, mBaseEngineConfig.vocabSize);
        auto const indicesShape = multimodalIndicesCPU.getShape();
        size_t const indicesSizeBytes = indicesShape.volume() * sizeof(int32_t);
        check::check(mMultimodalIndices.reshape(indicesShape), "Tensor reshape failed");
        CUDA_CHECK(cudaMemcpy(mMultimodalIndices.rawPointer(), multimodalIndicesCPU.rawPointer(), indicesSizeBytes,
            cudaMemcpyHostToDevice));

        kernel::embeddingLookupMultimodal(mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(),
            std::optional{std::ref(mMultimodalIndices)}, imageTokenId, context.visualEmbeddings, audioTokenId,
            context.audioEmbeddings, mInputsEmbeds, context.stream);
    }
    else if (useIncrementalVisualID)
    {
        // Vision-only legacy path (``tokenId >= vocabSize`` remapping).
        rt::Tensor const& imageEmbedsTensor = context.visualEmbeddings.value().get();
        kernel::embeddingLookupWithImageInsertion(mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(),
            imageEmbedsTensor, mInputsEmbeds, context.stream);
    }
    else
    {
        // Standard embedding lookup
        kernel::embeddingLookup(
            mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(), mInputsEmbeds, context.stream);
    }

    // Process deepstack features
    rt::OptionalInputTensors deepstackEmbeds{};
    if (mBaseEngineConfig.numDeepstackFeatures > 0)
    {
        if (!context.deepstackFeatures.empty())
        {
            rt::OptionalInputTensor deepstackMultimodalIndices{std::nullopt};
            if (mMultimodalIndices.getShape().volume() > 0)
            {
                deepstackMultimodalIndices = std::ref(mMultimodalIndices);
            }

            for (int32_t idx = 0; idx < static_cast<int32_t>(context.deepstackFeatures.size()); ++idx)
            {
                rt::Tensor const& featureTensor = context.deepstackFeatures[idx].get();

                check::check(
                    mDeepstackEmbeds[idx].reshape({activeBatchSize, inputIdsLength, mBaseEngineConfig.hiddenSize}),
                    "Tensor reshape failed");
                kernel::assembleDeepstackEmbedding(mIdsInput, featureTensor, mBaseEngineConfig.vocabSize,
                    mDeepstackEmbeds[idx], context.stream, mBaseEngineConfig.imageTokenId, deepstackMultimodalIndices);

                deepstackEmbeds.push_back(std::ref(mDeepstackEmbeds[idx]));
            }
        }
        else
        {
            // Text-only request on a VLM model: pass zero-filled deepstack embeds to engine.
            // The engine requires these bindings even when no vision data is present.
            LOG_DEBUG(
                "Deepstack features configured but not available for this text-only request, using zero tensors.");
            for (int32_t idx = 0; idx < mBaseEngineConfig.numDeepstackFeatures; ++idx)
            {
                check::check(
                    mDeepstackEmbeds[idx].reshape({activeBatchSize, inputIdsLength, mBaseEngineConfig.hiddenSize}),
                    "Tensor reshape failed");
                CUDA_CHECK(cudaMemsetAsync(
                    mDeepstackEmbeds[idx].rawPointer(), 0, mDeepstackEmbeds[idx].getMemoryCapacity(), context.stream));
                deepstackEmbeds.push_back(std::ref(mDeepstackEmbeds[idx]));
            }
        }
    }

    // Only request hidden states output when draft model needs them
    rt::OptionalOutputTensor hiddenStatesOutput
        = hasDraftModel() ? rt::OptionalOutputTensor{std::ref(mBaseHiddenStatesOutput)} : std::nullopt;
    bool const prefillSuccess = mBaseEngineRunner->executePrefillStep(
        mInputsEmbeds, mContextLengthsInput, deepstackEmbeds, mLogitsOutput, hiddenStatesOutput, context.stream);
    if (!prefillSuccess)
    {
        LOG_ERROR("Failed to execute prefill step for base model.");
        return false;
    }

    // Sampling from the prefill stage logits follows the same policy as later vanilla decoding.
    check::check(mSamplingIndices.reshape({activeBatchSize, 1}), "Tensor reshape failed");
    if (shouldUseNonGreedySampling(context.temperature, context.topK, context.topP))
    {
        SamplingParams params(activeBatchSize, mBaseEngineConfig.outputVocabSize, context.temperature,
            static_cast<int32_t>(context.topK), context.topP);
        topKtopPSamplingFromLogits(mLogitsOutput, mSamplingIndices, params, mSamplingWorkspace, context.stream);
    }
    else
    {
        constexpr int32_t kSAMPLING_TOP_K = 1;
        selectAllTopK(
            mLogitsOutput, std::nullopt, mSamplingIndices, kSAMPLING_TOP_K, mSamplingWorkspace, context.stream);
    }

    // Apply vocabulary mapping if base model uses reduced vocabulary
    if (mBaseEngineConfig.reducedVocabSize > 0)
    {
        mapReducedVocabToFullVocab(mSamplingIndices, mBaseVocabMappingTable, context.stream);
    }

    check::check(mHostSelectedTokenIds.reshape({activeBatchSize}), "Tensor reshape failed");
    int32_t* hostSelectedTokenIdsData = mHostSelectedTokenIds.dataPointer<int32_t>();
    CUDA_CHECK(cudaMemcpyAsync(hostSelectedTokenIdsData, mSamplingIndices.rawPointer(),
        activeBatchSize * sizeof(int32_t), cudaMemcpyDeviceToHost, context.stream));
    CUDA_CHECK(cudaStreamSynchronize(context.stream));

    // Update tokenIds and generation length for each sequence
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        if (!context.finishedStates[i])
        {
            context.tokenIds[i].push_back(hostSelectedTokenIdsData[i]);
            context.currentGenerateLengths[i] += 1;
        }
    }

    // The base prefill function produces output logits and (concatenated) hidden states for the next step to use.
    return true;
}

bool LLMInferenceSpecDecodeRuntime::runDraftModelPrefill(SpecDecodeInferenceContext& context)
{
    assert(mDraftEngineRunner != nullptr);
    assert(mDraftingConfig.has_value());
    TIME_STAGE(metrics::StageNames::kEAGLE_DRAFT_PREFILL, context.stream);
    NVTX_SCOPED_RANGE(nvtx_draft_prefill,
        ("EAGLE_DRAFT_PREFILL[R" + std::to_string(context.generationRound) + ","
            + std::to_string(context.activeBatchSize) + "]")
            .c_str(),
        nvtx_colors::DARK_ORANGE);

    int32_t const activeBatchSize = context.activeBatchSize;

    // Implement the draft prefill execution logic: prepare the input IDs and hidden-state inputs for the
    // Eagle draft engine. The feature formulation is F_n = F(H_n, Token_{n+1}), so we need to trim the
    // first token of the sequence from the token_ids input.

    // The draft model prefills the same number of tokens as the base model prefill.
    int32_t const inputIdsLength
        = *std::max_element(context.effectivePrefillLengths.begin(), context.effectivePrefillLengths.end());

    check::check(mBaseHiddenStatesOutput.getShape()[0] == activeBatchSize
            && mBaseHiddenStatesOutput.getShape()[1] == inputIdsLength,
        "BaseHiddenStatesOutput shape [batch, seq_len, hidden_dim] shall match with [activeBatchSize, inputIdsLength, "
        "hidden_dim]");

    // Prepare input and output tensors.
    check::check(mIdsInput.reshape({activeBatchSize, inputIdsLength}), "Tensor reshape failed");
    check::check(
        mDraftHiddenStatesInput.reshape({activeBatchSize, inputIdsLength, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");
    check::check(
        mLogitsOutput.reshape({activeBatchSize, mDraftEngineConfig->draftModelVocabSize}), "Tensor reshape failed");
    check::check(mDraftHiddenStatesOutput.reshape({activeBatchSize, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");

    // Clear garbage data in the draft hidden inputs.
    CUDA_CHECK(cudaMemsetAsync(
        mDraftHiddenStatesInput.rawPointer(), 0, mDraftHiddenStatesInput.getMemoryCapacity(), context.stream));

    // Copy input IDs for each batch to host pinned memory first (skip first token for draft model)
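    // Dropping the first token keeps hidden state H_n paired with Token_{n+1}, per the feature formulation above.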
    int32_t* idsInputData = mIdsInput.dataPointer<int32_t>();
    check::check(mHostPackedTokenIds.reshape({activeBatchSize, inputIdsLength}), "Tensor reshape failed");
    int32_t* hostPackedTokenIdsData = mHostPackedTokenIds.dataPointer<int32_t>();

    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        std::copy(
            context.tokenIds[i].begin() + 1, context.tokenIds[i].end(), hostPackedTokenIdsData + i * inputIdsLength);
    }

    CUDA_CHECK(cudaMemcpyAsync(idsInputData, hostPackedTokenIdsData, activeBatchSize * inputIdsLength * sizeof(int32_t),
        cudaMemcpyHostToDevice, context.stream));

    // Perform embedding lookup for draft model prefill
    check::check(mInputsEmbeds.reshape({activeBatchSize, inputIdsLength, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");

    if (context.visualEmbeddings.has_value())
    {
        // Use image insertion variant for multimodal models (draft model uses base model hidden dim / 3)
        rt::Tensor const& imageEmbedsTensor = context.visualEmbeddings.value().get();
        kernel::embeddingLookupWithImageInsertion(mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(),
            imageEmbedsTensor, mInputsEmbeds, context.stream);
    }
    else
    {
        // Standard embedding lookup
        kernel::embeddingLookup(
            mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(), mInputsEmbeds, context.stream);
    }

    bool const prefillSuccess = mDraftEngineRunner->executeEaglePrefillStep(mInputsEmbeds, mBaseHiddenStatesOutput,
        mDraftHiddenStatesInput, mContextLengthsInput, mLogitsOutput, mDraftHiddenStatesOutput,
        mBaseEngineRunner->getRopeCosSinCacheTensor(), context.stream);
    if (!prefillSuccess)
    {
        LOG_ERROR("Failed to execute prefill step for draft model.");
        return false;
    }
    // No need to do sampling, directly produce logits (mLogitsOutput) and hidden states (mDraftHiddenStatesOutput) for
    // the next step.
    return true;
}

bool LLMInferenceSpecDecodeRuntime::constructDraftTree(SpecDecodeInferenceContext& context)
{
    assert(mDraftEngineRunner != nullptr);
    assert(mDraftingConfig.has_value());
    TIME_STAGE(metrics::StageNames::kEAGLE_CONSTRUCT_DRAFT_TREE, context.stream);
    NVTX_SCOPED_RANGE(nvtx_construct_tree,
        ("EAGLE_CONSTRUCT_TREE[R" + std::to_string(context.generationRound) + ","
            + std::to_string(context.activeBatchSize) + "]")
            .c_str(),
        nvtx_colors::LIGHT_ORANGE);

    int32_t const activeBatchSize = context.activeBatchSize;

    // Core logic for EAGLE speculative decoding: construct the draft tree in an auto-regressive manner.
    // Inputs: logits (mLogitsOutput) and draft hidden states (mDraftHiddenStatesOutput) from the draft prefill
    // or the draft model accept-decoding operation.
    // The draft tree is constructed over multiple rounds of drafting and is described by mDraftTokenIdsFullTable,
    // mDraftTokenScoreFullTable, and mDraftTokenPredecessorFullTable.

    // Reshape draft tree tensors to match activeBatchSize for dynamic batching
    int32_t const draftTopK = mDraftingConfig->draftingTopK;
    int32_t const draftFullTableLength = static_cast<int32_t>(mDraftTokenIdsFullTable.getShape()[1]);
    check::check(mDraftTokenIdsFullTable.reshape({activeBatchSize, draftFullTableLength}), "Tensor reshape failed");
    check::check(mDraftTokenScoreFullTable.reshape({activeBatchSize, draftFullTableLength}), "Tensor reshape failed");
    check::check(
        mDraftTokenPredecessorFullTable.reshape({activeBatchSize, draftFullTableLength}), "Tensor reshape failed");
    check::check(mDraftTreeRootTokenId.reshape({activeBatchSize}), "Tensor reshape failed");
    check::check(mDraftTokenIdsTable.reshape({activeBatchSize, draftTopK * draftTopK}), "Tensor reshape failed");
    check::check(mDraftTokenScoresTable.reshape({activeBatchSize, draftTopK * draftTopK}), "Tensor reshape failed");
    check::check(mDraftTokenIntermediateScores.reshape({activeBatchSize, draftTopK}), "Tensor reshape failed");
    check::check(mDraftTokenIntermediateParents.reshape({activeBatchSize, draftTopK}), "Tensor reshape failed");

    // Record the root token id (the last committed token selected by the base model) of the draft tree for each batch.
    std::vector<int32_t> rootTokenIds(activeBatchSize);
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        rootTokenIds[i] = context.tokenIds[i].back();
    }
    CUDA_CHECK(cudaMemcpyAsync(mDraftTreeRootTokenId.rawPointer(), rootTokenIds.data(),
        activeBatchSize * sizeof(int32_t), cudaMemcpyHostToDevice, context.stream));

    // Sample from the logits output and collect draftTopK tokens as the first level under the root.
    check::check(mSamplingIndices.reshape({activeBatchSize, draftTopK}), "Tensor reshape failed");
    check::check(mSamplingScores.reshape({activeBatchSize, draftTopK}), "Tensor reshape failed");
    selectAllTopK(
        mLogitsOutput, std::ref(mSamplingScores), mSamplingIndices, draftTopK, mSamplingWorkspace, context.stream);

    // Initialize data structures to describe the whole draft tree.
    kernel::initializeDraftTreeTables(mSamplingIndices, mSamplingScores, mDraftTreeRootTokenId, mDraftVocabMappingTable,
        mDraftTokenIdsFullTable, mDraftTokenScoreFullTable, mDraftTokenPredecessorFullTable, draftTopK, context.stream);
    // Reset hidden states output of base model and input hidden states of draft model to clear the garbage data.
    CUDA_CHECK(cudaMemsetAsync(
        mBaseHiddenStatesOutput.rawPointer(), 0, mBaseHiddenStatesOutput.getMemoryCapacity(), context.stream));
    CUDA_CHECK(cudaMemsetAsync(
        mDraftHiddenStatesInput.rawPointer(), 0, mDraftHiddenStatesInput.getMemoryCapacity(), context.stream));

    // Construct input tensors to feed into the eagle draft engine. With current implementation, for simplicity, we
    // will use padded input and only collect results from indices we need.
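    // Each of the draftingStep tree levels contributes at most draftTopK tokens, hence draftingStep * draftTopK slots.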
    int32_t const paddedDraftTreeSize = mDraftingConfig->draftingStep * draftTopK;
    check::check(mIdsInput.reshape({activeBatchSize, paddedDraftTreeSize}), "Tensor reshape failed");
    check::check(
        mBaseHiddenStatesOutput.reshape({activeBatchSize, paddedDraftTreeSize, mBaseEngineConfig.outputHiddenDim}),
        "Tensor reshape failed");
    check::check(mDraftHiddenStatesInput.reshape(
                     {activeBatchSize, paddedDraftTreeSize, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");
    check::check(mDraftTreeSize.reshape({activeBatchSize}), "Tensor reshape failed");
    check::check(
        mDraftTreeMask.reshape({activeBatchSize, paddedDraftTreeSize, paddedDraftTreeSize}), "Tensor reshape failed");
    // Assemble the initial draft tree input here since we need to copy out the data in draftHiddenStatesOutput prior to
    // reshaping it.
    kernel::assembleInitialDraftTreeInput(mDraftTokenIdsFullTable, mDraftHiddenStatesOutput, mIdsInput,
        mDraftHiddenStatesInput, mDraftTreeSize, mDraftTreeMask, draftTopK, context.stream);
    // Output tensors must be 3D: [batch_size, num_tokens, vocab_size/hidden_dim] for draft proposal
    check::check(mLogitsOutput.reshape({activeBatchSize, draftTopK, mDraftEngineConfig->draftModelVocabSize}),
        "Tensor reshape failed");
    check::check(
        mDraftHiddenStatesOutput.reshape({activeBatchSize, draftTopK, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");

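    // The first tree level was already sampled from the prefill/accept logits above, so only draftingStep - 1
    // additional expansion rounds are needed here.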
    for (int32_t round = 0; round < mDraftingConfig->draftingStep - 1; round++)
    {
        if (round == 0)
        {
            // For the first round of drafting, the input tensors have already been assembled; we only need to save
            // intermediate information for the next step.
            kernel::assembleInitialIntermediateData(mSamplingScores, mDraftTokenIntermediateParents,
                mDraftTokenIntermediateScores, draftTopK, context.stream);
        }
        else
        {
            // The previous round of drafting produced draftTopK x draftTopK candidate tokens for this layer; we need
            // to pick the top draftTopK, assemble the input tensors, and save intermediate information.
            check::check(mSamplingIndices.reshape({activeBatchSize, draftTopK}), "Tensor reshape failed");
            check::check(mSamplingScores.reshape({activeBatchSize, draftTopK}), "Tensor reshape failed");
            selectAllTopK(mDraftTokenScoresTable, std::ref(mSamplingScores), mSamplingIndices, draftTopK,
                mSamplingWorkspace, context.stream);
            kernel::assembleDraftTreeInput(mDraftTokenIdsTable, mDraftHiddenStatesOutput, mSamplingIndices, mIdsInput,
                mDraftHiddenStatesInput, mDraftTreeSize, mDraftTreeMask, draftTopK, round, context.stream);
            kernel::assembleIntermediateData(mSamplingScores, mSamplingIndices, mDraftTokenIntermediateScores,
                mDraftTokenIntermediateParents, draftTopK, round, context.stream);
        }

        // Ensure output tensors are 3D before calling draft proposal
        check::check(mLogitsOutput.reshape({activeBatchSize, draftTopK, mDraftEngineConfig->draftModelVocabSize}),
            "Tensor reshape failed");
        check::check(
            mDraftHiddenStatesOutput.reshape({activeBatchSize, draftTopK, mDraftEngineConfig->draftModelHiddenDim}),
            "Tensor reshape failed");

        // Perform embedding lookup for draft proposal input (draft proposal only has text, no images)
        check::check(
            mInputsEmbeds.reshape({activeBatchSize, paddedDraftTreeSize, mDraftEngineConfig->draftModelHiddenDim}),
            "Tensor reshape failed");
        {
            kernel::embeddingLookup(
                mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(), mInputsEmbeds, context.stream);
        }

        // Invoke the eagle draft engine to produce the new round of logits and hidden states.
        bool const draftProposalStatus = mDraftEngineRunner->executeEagleDraftProposalStep(mInputsEmbeds,
            mBaseHiddenStatesOutput, mDraftHiddenStatesInput, mDraftTreeSize, mDraftTreeMask, mLogitsOutput,
            mDraftHiddenStatesOutput, context.stream);
        if (!draftProposalStatus)
        {
            LOG_ERROR("Failed to execute draft proposal step for draft model.");
            return false;
        }
        // Collect TopK results from each lane of output logits.
        // mLogitsOutput is now 3D: {activeBatchSize, draftTopK, vocabSize}
        // Reshape to 2D for selectAllTopK: {activeBatchSize * draftTopK, vocabSize}
        check::check(mLogitsOutput.reshape({activeBatchSize * draftTopK, mDraftEngineConfig->draftModelVocabSize}),
            "Tensor reshape failed");
        check::check(
            mDraftHiddenStatesOutput.reshape({activeBatchSize * draftTopK, mDraftEngineConfig->draftModelHiddenDim}),
            "Tensor reshape failed");
        check::check(mSamplingIndices.reshape({activeBatchSize * draftTopK, draftTopK}), "Tensor reshape failed");
        check::check(mSamplingScores.reshape({activeBatchSize * draftTopK, draftTopK}), "Tensor reshape failed");
        selectAllTopK(
            mLogitsOutput, std::ref(mSamplingScores), mSamplingIndices, draftTopK, mSamplingWorkspace, context.stream);

        // Reshape sampling indices/scores back to the expected format for subsequent kernel calls
        // Note: mDraftHiddenStatesOutput stays in 2D format for assembleDraftTreeInput in next round
        check::check(mSamplingIndices.reshape({activeBatchSize, draftTopK * draftTopK}), "Tensor reshape failed");
        check::check(mSamplingScores.reshape({activeBatchSize, draftTopK * draftTopK}), "Tensor reshape failed");

        // Update the draft tree tables with the new topK results and translate draft-vocab tokens to the full vocab.
        kernel::computeCuScoresAndTranslateToken(mSamplingIndices, mSamplingScores, mDraftTokenIntermediateScores,
            mDraftVocabMappingTable, mDraftTokenIdsTable, mDraftTokenScoresTable, draftTopK, context.stream);
        // Update results in the full draft table
        kernel::updateDraftTreeFullTables(mDraftTokenIdsTable, mDraftTokenScoresTable, mDraftTokenIntermediateParents,
            mDraftTokenIdsFullTable, mDraftTokenScoreFullTable, mDraftTokenPredecessorFullTable, draftTopK, round,
            context.stream);
    }

    // The draft table data structure is now complete; pick the top candidates to produce the verification tree and
    // pass it into the base model for verification.
    check::check(mSamplingIndices.reshape({activeBatchSize, mDraftingConfig->verifyTreeSize}), "Tensor reshape failed");
    selectAllTopK(mDraftTokenScoreFullTable, std::nullopt, mSamplingIndices, mDraftingConfig->verifyTreeSize,
        mSamplingWorkspace, context.stream);

    check::check(mIdsInput.reshape({activeBatchSize, mDraftingConfig->verifyTreeSize}), "Tensor reshape failed");
    check::check(
        mDraftTreeMask.reshape({activeBatchSize, mDraftingConfig->verifyTreeSize, mDraftingConfig->verifyTreeSize}),
        "Tensor reshape failed");
    kernel::constructVerificationDraftTree(mDraftTokenIdsFullTable, mDraftTokenPredecessorFullTable, mSamplingIndices,
        mIdsInput, mDraftTreeMask, context.stream);

    // This function produces mIdsInput and mDraftTreeMask to describe the established draft tree.
    return true;
}

bool LLMInferenceSpecDecodeRuntime::runBaseModelVerification(SpecDecodeInferenceContext& context)
{
    assert(mDraftEngineRunner != nullptr);
    assert(mDraftingConfig.has_value());
    TIME_STAGE(metrics::StageNames::kEAGLE_BASE_VERIFICATION, context.stream);
    NVTX_SCOPED_RANGE(nvtx_verify,
        ("EAGLE_VERIFY[R" + std::to_string(context.generationRound) + "," + std::to_string(context.activeBatchSize)
            + "]")
            .c_str(),
        nvtx_colors::MAGENTA);

    int32_t const activeBatchSize = context.activeBatchSize;

    // This function consumes mIdsInput and mDraftTreeMask and uses the base model to verify the draft tree.
    // We need to collect the logits and hidden states (for the subsequent drafting step).
    check::check(
        mIdsInput.getShape()[0] == activeBatchSize && mIdsInput.getShape()[1] == mDraftingConfig->verifyTreeSize,
        "IdsInput shall have shape [batch_size, verify_tree_size]");
    check::check(mDraftTreeMask.getShape()[0] == activeBatchSize
            && mDraftTreeMask.getShape()[1] == mDraftingConfig->verifyTreeSize
            && mDraftTreeMask.getShape()[2] == mDraftingConfig->verifyTreeSize,
        "DraftTreeMask shall have shape [batch_size, verify_tree_size, verify_tree_size]");

    // Perform embedding lookup for base model verification (Eagle base tree decoding only has text, no images)
    check::check(
        mInputsEmbeds.reshape({activeBatchSize, mDraftingConfig->verifyTreeSize, mBaseEngineConfig.hiddenSize}),
        "Tensor reshape failed");
    {
        kernel::embeddingLookup(
            mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(), mInputsEmbeds, context.stream);
    }

    // Engine expects 2D tensors: [batch_size * verify_tree_size, vocab_size/hidden_dim]
    int32_t const selectTokenSize = activeBatchSize * mDraftingConfig->verifyTreeSize;
    check::check(mLogitsOutput.reshape({selectTokenSize, mBaseEngineConfig.outputVocabSize}), "Tensor reshape failed");
    check::check(
        mBaseHiddenStatesOutput.reshape({selectTokenSize, mBaseEngineConfig.outputHiddenDim}), "Tensor reshape failed");

    bool const verifySuccess = mBaseEngineRunner->executeEagleBaseTreeDecodingStep(
        mInputsEmbeds, mDraftTreeMask, mLogitsOutput, mBaseHiddenStatesOutput, context.stream);
    if (!verifySuccess)
    {
        LOG_ERROR("Failed to execute base tree verification step for base model.");
        return false;
    }

    // Reshape accepted token tensors to match activeBatchSize for dynamic batching
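    // The longest accepted path holds at most draftingStep draft tokens plus one bonus token from the base model.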
    int32_t const maxAcceptDepth = mDraftingConfig->draftingStep + 1;
    check::check(mAcceptedTokenIds.reshape({activeBatchSize, maxAcceptDepth}), "Tensor reshape failed");
    check::check(mAcceptedTokenIndices.reshape({activeBatchSize, maxAcceptDepth}), "Tensor reshape failed");
    check::check(mAcceptLength.reshape({activeBatchSize}), "Tensor reshape failed");

    // Collect accepted token ids and indices. Use the sampling workspace for the eagle accept process.
    // Pass vocab mapping table if base model uses reduced vocabulary
    // Note: accepted tokens will already be in full vocab space after eagleAccept (mapping happens inside kernel)
    rt::OptionalInputTensor vocabMappingTable
        = (mBaseEngineConfig.reducedVocabSize > 0) ? std::optional{std::ref(mBaseVocabMappingTable)} : std::nullopt;
    kernel::eagleAccept(mLogitsOutput, mIdsInput, mDraftTreeMask, mAcceptedTokenIds, mAcceptedTokenIndices,
        mAcceptLength, vocabMappingTable, mSamplingWorkspace.rawPointer(), mSamplingWorkspace.getMemoryCapacity(),
        context.stream);

    // Inplace update the KVCache and input hidden states from the accepted token indices.
    // Also commit KVCache to reflect the latest KVCache length (We can only do this after knowing how many tokens are
    // accepted).
    auto& baseCacheManager = mBaseEngineRunner->getCacheManager();
    rt::Tensor const& kvCacheLengths = baseCacheManager.getKVCacheLengths();

    // The EAGLE base verify is per-head-dim-group batched: one launch per group covers
    // every layer in that group, addressing per-layer storage through a device-resident
    // KVLayerInfo array owned by HybridCacheManager. Uniform models (every model that goes
    // through EAGLE today) yield a single group; hybrid Gemma4-style layouts would yield
    // one group per distinct headDim with no further code change.
    auto const kvHeadDimGroups = baseCacheManager.getKVHeadDimGroups();
    auto const kvCacheType = baseCacheManager.getKVCacheManager().getConfig().kvCacheType;

    // Reshape input hidden states from 2D [batch*verify_tree_size, hidden_dim] to 3D
    // [batch, verify_tree_size, hidden_dim]
    check::check(mBaseHiddenStatesOutput.reshape(
                     {activeBatchSize, mDraftingConfig->verifyTreeSize, mBaseEngineConfig.outputHiddenDim}),
        "Tensor reshape failed");

    // INPLACE updates:
    //   - eagleBaseCommitKVCache rewrites each layer's KV cache so accepted tokens occupy
    //     contiguous slots starting at pastKvCacheLength.
    //   - eagleBaseAssembleHiddenState compacts the hidden-state buffer:
    //       Before: [Batch0: Token0...Token59][Batch1: Token0...Token59]...   (stride=60 per batch)
    //       After:  [Batch0: Sel0...Sel6][Batch1: Sel0...Sel6]...              (stride=maxAcceptDepth)
    for (auto const& group : kvHeadDimGroups)
    {
        kernel::eagleBaseCommitKVCache(mAcceptedTokenIndices, mAcceptLength, kvCacheLengths, group.deviceLayerInfos,
            group.numLayers, group.headDim, group.maxKVHeads, activeBatchSize, maxAcceptDepth, kvCacheType,
            context.stream);
    }
    kernel::eagleBaseAssembleHiddenState(mAcceptedTokenIndices, mAcceptLength, mBaseHiddenStatesOutput, context.stream);

    baseCacheManager.commitSequenceLength(mAcceptLength, context.stream);

    // Reshape to reflect the compacted layout [batch, maxAcceptDepth, hiddenDim]
    check::check(mBaseHiddenStatesOutput.reshape({activeBatchSize, maxAcceptDepth, mBaseEngineConfig.outputHiddenDim}),
        "Tensor reshape failed");

    // Pull collected results from device to host pinned memory for all batches
    check::check(mHostAcceptLengths.reshape({activeBatchSize}), "Tensor reshape failed");
    check::check(mHostAcceptedTokenIds.reshape({activeBatchSize, maxAcceptDepth}), "Tensor reshape failed");
    int32_t* hostAcceptLengthsData = mHostAcceptLengths.dataPointer<int32_t>();
    int32_t* hostAcceptedTokenIdsData = mHostAcceptedTokenIds.dataPointer<int32_t>();

    CUDA_CHECK(cudaMemcpyAsync(hostAcceptLengthsData, mAcceptLength.rawPointer(), activeBatchSize * sizeof(int32_t),
        cudaMemcpyDeviceToHost, context.stream));
    CUDA_CHECK(cudaMemcpyAsync(hostAcceptedTokenIdsData, mAcceptedTokenIds.rawPointer(),
        activeBatchSize * maxAcceptDepth * sizeof(int32_t), cudaMemcpyDeviceToHost, context.stream));
    CUDA_CHECK(cudaStreamSynchronize(context.stream));

    // Update tokenIds and check for EOS for each batch
    for (int32_t batchIdx = 0; batchIdx < activeBatchSize; ++batchIdx)
    {
        int32_t const acceptLength = hostAcceptLengthsData[batchIdx];
        for (int32_t i = 0; i < acceptLength; i++)
        {
            int32_t const token = hostAcceptedTokenIdsData[batchIdx * maxAcceptDepth + i];
            context.tokenIds[batchIdx].push_back(token);
            context.currentGenerateLengths[batchIdx]++;

            // Discard any tokens after EOS.
            if (token == mTokenizer->getEosId())
            {
                break;
            }
        }
    }

    // Base model hidden states for the next step are left in mBaseHiddenStatesOutput,
    // already compacted and reshaped to [batch, maxAcceptDepth, hiddenDim] above.
    return true;
}

bool LLMInferenceSpecDecodeRuntime::runVanillaDecoding(SpecDecodeInferenceContext& context)
{
    TIME_STAGE(metrics::StageNames::kLLM_GENERATION, context.stream);
    NVTX_SCOPED_RANGE(nvtx_vanilla_decoding,
        ("VANILLA_DECODING[R" + std::to_string(context.generationRound) + "," + std::to_string(context.activeBatchSize)
            + "]")
            .c_str(),
        nvtx_colors::BLUE);

    int32_t const activeBatchSize = context.activeBatchSize;
    check::check(mHostPackedTokenIds.reshape({activeBatchSize}), "Tensor reshape failed");
    int32_t* hostPackedTokenIdsData = mHostPackedTokenIds.dataPointer<int32_t>();

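    // Gather the last committed token of each sequence; vanilla decoding advances one token per step.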
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        int32_t const lastTokenId = context.tokenIds[i].back();
        hostPackedTokenIdsData[i] = lastTokenId;
    }

    check::check(mIdsInput.reshape({activeBatchSize, 1}), "Tensor reshape failed");
    CUDA_CHECK(cudaMemcpyAsync(mIdsInput.rawPointer(), mHostPackedTokenIds.rawPointer(),
        activeBatchSize * sizeof(int32_t), cudaMemcpyHostToDevice, context.stream));

    check::check(mInputsEmbeds.reshape({activeBatchSize, 1, mBaseEngineConfig.hiddenSize}), "Tensor reshape failed");
    {
        kernel::embeddingLookup(
            mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(), mInputsEmbeds, context.stream);
    }

    check::check(mLogitsOutput.reshape({activeBatchSize, mBaseEngineConfig.outputVocabSize}), "Tensor reshape failed");

    // No hidden states output needed for vanilla decoding.
    rt::OptionalOutputTensor const outputHiddenStates{std::nullopt};
    bool const vanillaDecodingSuccess = mBaseEngineRunner->executeVanillaDecodingStep(
        mInputsEmbeds, mLogitsOutput, outputHiddenStates, context.stream);
    if (!vanillaDecodingSuccess)
    {
        LOG_ERROR("Failed to execute vanilla decoding step for base model.");
        return false;
    }

    // Use topKtopPSampling when sampling params differ from greedy defaults (temperature=1.0, topK<=1, topP=1.0).
    // Temperature <= 1e-3 is treated as greedy to avoid softmax numerical instability.
    check::check(mSamplingIndices.reshape({activeBatchSize, 1}), "Tensor reshape failed");
    if (shouldUseNonGreedySampling(context.temperature, context.topK, context.topP))
    {
        SamplingParams params(activeBatchSize, mBaseEngineConfig.outputVocabSize, context.temperature,
            static_cast<int32_t>(context.topK), context.topP);
        topKtopPSamplingFromLogits(mLogitsOutput, mSamplingIndices, params, mSamplingWorkspace, context.stream);
    }
    else
    {
        // Greedy decoding (temperature ~= 0 or default)
        constexpr int32_t kSAMPLING_TOP_K = 1;
        selectAllTopK(
            mLogitsOutput, std::nullopt, mSamplingIndices, kSAMPLING_TOP_K, mSamplingWorkspace, context.stream);
    }

    // Apply vocabulary mapping if base model uses reduced vocabulary
    if (mBaseEngineConfig.reducedVocabSize > 0)
    {
        mapReducedVocabToFullVocab(mSamplingIndices, mBaseVocabMappingTable, context.stream);
    }

    check::check(mHostSelectedTokenIds.reshape({activeBatchSize}), "Tensor reshape failed");
    int32_t* hostSelectedTokenIdsData = mHostSelectedTokenIds.dataPointer<int32_t>();
    CUDA_CHECK(cudaMemcpyAsync(hostSelectedTokenIdsData, mSamplingIndices.rawPointer(),
        activeBatchSize * sizeof(int32_t), cudaMemcpyDeviceToHost, context.stream));
    CUDA_CHECK(cudaStreamSynchronize(context.stream));

    // Update tokenIds and generation length for each sequence
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        context.tokenIds[i].push_back(hostSelectedTokenIdsData[i]);
        context.currentGenerateLengths[i] += 1;
    }

    return true;
}

bool LLMInferenceSpecDecodeRuntime::runDraftModelAcceptToken(SpecDecodeInferenceContext& context)
{
    assert(mDraftEngineRunner != nullptr);
    assert(mDraftingConfig.has_value());
    NVTX_SCOPED_RANGE(nvtx_draft_accept,
        ("EAGLE_DRAFT_ACCEPT[R" + std::to_string(context.generationRound) + ","
            + std::to_string(context.activeBatchSize) + "]")
            .c_str(),
        nvtx_colors::YELLOW);

    int32_t const activeBatchSize = context.activeBatchSize;

    // The base model verification function is responsible for producing the output with the correct shape.
    // Shape is [activeBatchSize, accepted_length, hidden_dim]
    int64_t const inputIdsLength = mBaseHiddenStatesOutput.getShape()[1];

    // Prepare input and output tensors.
    check::check(mIdsInput.reshape({activeBatchSize, inputIdsLength}), "Tensor reshape failed");
    check::check(
        mDraftHiddenStatesInput.reshape({activeBatchSize, inputIdsLength, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");
    check::check(
        mLogitsOutput.reshape({activeBatchSize, mDraftEngineConfig->draftModelVocabSize}), "Tensor reshape failed");
    check::check(mDraftHiddenStatesOutput.reshape({activeBatchSize, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");

    // Clear garbage data in the draft hidden inputs.
    CUDA_CHECK(cudaMemsetAsync(
        mDraftHiddenStatesInput.rawPointer(), 0, mDraftHiddenStatesInput.getMemoryCapacity(), context.stream));

    // Prepare mIdsInput by copying from mAcceptedTokenIds for each batch.
    // mAcceptedTokenIds is [activeBatchSize, maxAcceptDepth]; we copy the first inputIdsLength tokens.
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        CUDA_CHECK(cudaMemcpyAsync(static_cast<int32_t*>(mIdsInput.rawPointer()) + i * inputIdsLength,
            static_cast<int32_t*>(mAcceptedTokenIds.rawPointer()) + i * mAcceptedTokenIds.getShape()[1],
            inputIdsLength * sizeof(int32_t), cudaMemcpyDeviceToDevice, context.stream));
    }

    // Perform embedding lookup for draft model accept decode token (only text, no images)
    check::check(mInputsEmbeds.reshape({activeBatchSize, inputIdsLength, mDraftEngineConfig->draftModelHiddenDim}),
        "Tensor reshape failed");
    {
        kernel::embeddingLookup(
            mIdsInput, mEmbedding.table, mEmbedding.scalesAsOptional(), mInputsEmbeds, context.stream);
    }

    bool const acceptTokenSuccess
        = mDraftEngineRunner->executeEagleAcceptDecodeTokenStep(mInputsEmbeds, mBaseHiddenStatesOutput,
            mDraftHiddenStatesInput, mAcceptLength, mLogitsOutput, mDraftHiddenStatesOutput, context.stream);
    if (!acceptTokenSuccess)
    {
        LOG_ERROR("Failed to execute accept token step for draft model.");
        return false;
    }
    // No need to do sampling, directly produce logits (mLogitsOutput) and hidden states (mDraftHiddenStatesOutput) for
    // the next step.
    return true;
}

bool LLMInferenceSpecDecodeRuntime::captureDecodingCUDAGraph(cudaStream_t stream)
{
    bool draftProposalCaptureStatus{true};
    bool draftAcceptCaptureStatus{true};
    bool baseVerificationCaptureStatus{true};
    bool baseVanillaDecodingCaptureStatus{true};

    // Capture CUDA graph for all supported batch sizes
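    // One set of graphs is recorded per batch size so decoding can replay the graph matching the active batch.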
    for (int32_t batchSize = 1; batchSize <= mMaxRuntimeBatchSize; ++batchSize)
    {
        // Draft-specific captures (only when draft model is present)
        if (mDraftEngineRunner != nullptr)
        {
            int32_t const draftTopK = mDraftingConfig->draftingTopK;
            int32_t const paddedDraftTreeSize = mDraftingConfig->draftingStep * draftTopK;
            int32_t const draftingStep = mDraftingConfig->draftingStep;

            // Draft proposal capture
            check::check(
                mBaseHiddenStatesOutput.reshape({batchSize, paddedDraftTreeSize, mBaseEngineConfig.outputHiddenDim}),
                "Tensor reshape failed");
            check::check(mDraftHiddenStatesInput.reshape(
                             {batchSize, paddedDraftTreeSize, mDraftEngineConfig->draftModelHiddenDim}),
                "Tensor reshape failed");
            check::check(mDraftTreeSize.reshape({batchSize}), "Tensor reshape failed");
            check::check(
                mDraftTreeMask.reshape({batchSize, paddedDraftTreeSize, paddedDraftTreeSize}), "Tensor reshape failed");
            // Output tensors must be 3D: [batch_size, num_tokens, vocab_size/hidden_dim] not 2D
            check::check(mLogitsOutput.reshape({batchSize, draftTopK, mDraftEngineConfig->draftModelVocabSize}),
                "Tensor reshape failed");
            check::check(
                mDraftHiddenStatesOutput.reshape({batchSize, draftTopK, mDraftEngineConfig->draftModelHiddenDim}),
                "Tensor reshape failed");
            check::check(
                mInputsEmbeds.reshape({batchSize, paddedDraftTreeSize, mDraftEngineConfig->draftModelHiddenDim}),
                "Tensor reshape failed");

            draftProposalCaptureStatus &= mDraftEngineRunner->captureEagleDraftProposalCudaGraph(mInputsEmbeds,
                mBaseHiddenStatesOutput, mDraftHiddenStatesInput, mDraftTreeSize, mDraftTreeMask, mLogitsOutput,
                mDraftHiddenStatesOutput, stream);

            // Draft accept decode token capture
            check::check(
                mLogitsOutput.reshape({batchSize, mDraftEngineConfig->draftModelVocabSize}), "Tensor reshape failed");
            check::check(mDraftHiddenStatesOutput.reshape({batchSize, mDraftEngineConfig->draftModelHiddenDim}),
                "Tensor reshape failed");

            // Don't pass multimodal embeddings during CUDA graph capture as they are invalid
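            // Capture one accept-decode graph per accept length in [1, draftingStep + 1] so replay matches the
            // runtime shape.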
            for (int32_t acceptLength = 1; acceptLength <= draftingStep + 1; acceptLength++)
            {
                check::check(
                    mBaseHiddenStatesOutput.reshape({batchSize, acceptLength, mBaseEngineConfig.outputHiddenDim}),
                    "Tensor reshape failed");
                check::check(
                    mDraftHiddenStatesInput.reshape({batchSize, acceptLength, mDraftEngineConfig->draftModelHiddenDim}),
                    "Tensor reshape failed");

                // Create a temporary acceptLength tensor for CUDA graph capture
                // All batches use the same acceptLength during graph capture
                check::check(mAcceptLength.reshape({batchSize}), "Tensor reshape failed");
                std::vector<int32_t> acceptLengthsVec(batchSize, acceptLength);
                CUDA_CHECK(cudaMemcpyAsync(mAcceptLength.rawPointer(), acceptLengthsVec.data(),
                    batchSize * sizeof(int32_t), cudaMemcpyHostToDevice, stream));

                check::check(mInputsEmbeds.reshape({batchSize, acceptLength, mDraftEngineConfig->draftModelHiddenDim}),
                    "Tensor reshape failed");

                draftAcceptCaptureStatus &= mDraftEngineRunner->captureEagleAcceptDecodeTokenCudaGraph(mInputsEmbeds,
                    mBaseHiddenStatesOutput, mDraftHiddenStatesInput, mAcceptLength, mLogitsOutput,
                    mDraftHiddenStatesOutput, stream);
            }

            // Base verification capture
            // Engine expects 2D tensors: [batch_size * verify_tree_size, vocab_size/hidden_dim]
            int32_t const selectTokenSize = batchSize * mDraftingConfig->verifyTreeSize;
            check::check(
                mLogitsOutput.reshape({selectTokenSize, mBaseEngineConfig.outputVocabSize}), "Tensor reshape failed");
            check::check(mBaseHiddenStatesOutput.reshape({selectTokenSize, mBaseEngineConfig.outputHiddenDim}),
                "Tensor reshape failed");
            check::check(
                mDraftTreeMask.reshape({batchSize, mDraftingConfig->verifyTreeSize, mDraftingConfig->verifyTreeSize}),
                "Tensor reshape failed");

            check::check(
                mInputsEmbeds.reshape({batchSize, mDraftingConfig->verifyTreeSize, mBaseEngineConfig.hiddenSize}),
                "Tensor reshape failed");

            baseVerificationCaptureStatus &= mBaseEngineRunner->captureEagleBaseTreeDecodingCudaGraph(
                mInputsEmbeds, mDraftTreeMask, mLogitsOutput, mBaseHiddenStatesOutput, mEmptyLoraWeightsName, stream);
            if (mBaseEngineConfig.maxSupportedLoraRank > 0)
            {
                for (auto const& loraWeightsName : mBaseEngineRunner->getAvailableLoraWeights())
                {
                    baseVerificationCaptureStatus &= mBaseEngineRunner->captureEagleBaseTreeDecodingCudaGraph(
                        mInputsEmbeds, mDraftTreeMask, mLogitsOutput, mBaseHiddenStatesOutput, loraWeightsName, stream);
                }
            }
        }

        // Base Vanilla Decoding capture (always needed).
        check::check(mInputsEmbeds.reshape({batchSize, 1, mBaseEngineConfig.hiddenSize}), "Tensor reshape failed");
        check::check(mLogitsOutput.reshape({batchSize, mBaseEngineConfig.outputVocabSize}), "Tensor reshape failed");

        baseVanillaDecodingCaptureStatus &= mBaseEngineRunner->captureVanillaDecodingCudaGraph(
            mInputsEmbeds, mLogitsOutput, mEmptyLoraWeightsName, stream);
        if (mBaseEngineConfig.maxSupportedLoraRank > 0)
        {
            for (auto const& loraWeightsName : mBaseEngineRunner->getAvailableLoraWeights())
            {
                baseVanillaDecodingCaptureStatus &= mBaseEngineRunner->captureVanillaDecodingCudaGraph(
                    mInputsEmbeds, mLogitsOutput, loraWeightsName, stream);
            }
        }
    }

    bool const captureStatus = draftProposalCaptureStatus && draftAcceptCaptureStatus && baseVerificationCaptureStatus
        && baseVanillaDecodingCaptureStatus;
    if (captureStatus)
    {
        LOG_INFO("Successfully captured decoding CUDA graphs for all stages.");
    }
    else
    {
        LOG_WARNING(
            "Failed to capture decoding CUDA graphs for some stages. Inference can proceed without "
            "CUDA graph capture, but at the cost of degraded performance.");
    }

    return captureStatus;
}

void LLMInferenceSpecDecodeRuntime::restoreRecurrentStates(
    int32_t batchIdx, SystemPromptKVCache const& cachedStates, cudaStream_t stream)
{
    rt::MambaCacheManager& mambaCache = mBaseEngineRunner->getCacheManager().getMambaCacheManager();
    rt::MambaCacheManager::Config const& mambaConfig = mambaCache.getConfig();
    size_t const recurrentElemSize = rt::utils::getTypeSize(mambaConfig.recurrentStateType);
    size_t const convElemSize = rt::utils::getTypeSize(mambaConfig.convStateType);
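    // Per-sequence byte strides: the recurrent state spans numHeads * headDim * stateSize elements and the conv
    // state spans convDim * convKernel elements.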
    size_t const recurrentBatchBytes = static_cast<size_t>(mambaConfig.recurrentStateNumHeads
                                           * mambaConfig.recurrentStateHeadDim * mambaConfig.recurrentStateSize)
        * recurrentElemSize;
    size_t const convBatchBytes = static_cast<size_t>(mambaConfig.convDim * mambaConfig.convKernel) * convElemSize;

    for (int32_t layer = 0; layer < mBaseEngineConfig.numLinearAttnLayers; ++layer)
    {
        rt::Tensor& recurrentLayer = mambaCache.getRecurrentState(layer);
        rt::Tensor& convLayer = mambaCache.getConvState(layer);

        auto* recurrentDst = static_cast<std::byte*>(recurrentLayer.rawPointer()) + batchIdx * recurrentBatchBytes;
        auto* convDst = static_cast<std::byte*>(convLayer.rawPointer()) + batchIdx * convBatchBytes;

        if (layer < static_cast<int32_t>(cachedStates.recurrentStateContents.size()))
        {
            CUDA_CHECK(cudaMemcpyAsync(recurrentDst, cachedStates.recurrentStateContents[layer].rawPointer(),
                recurrentBatchBytes, cudaMemcpyDeviceToDevice, stream));
        }
        else
        {
            CUDA_CHECK(cudaMemsetAsync(recurrentDst, 0, recurrentBatchBytes, stream));
        }

        if (layer < static_cast<int32_t>(cachedStates.convStateContents.size()))
        {
            CUDA_CHECK(cudaMemcpyAsync(convDst, cachedStates.convStateContents[layer].rawPointer(), convBatchBytes,
                cudaMemcpyDeviceToDevice, stream));
        }
        else
        {
            CUDA_CHECK(cudaMemsetAsync(convDst, 0, convBatchBytes, stream));
        }
    }
}

void LLMInferenceSpecDecodeRuntime::zeroRecurrentStates(int32_t batchIdx, cudaStream_t stream)
{
    rt::MambaCacheManager& mambaCache = mBaseEngineRunner->getCacheManager().getMambaCacheManager();
    rt::MambaCacheManager::Config const& mambaConfig = mambaCache.getConfig();
    size_t const recurrentElemSize = rt::utils::getTypeSize(mambaConfig.recurrentStateType);
    size_t const convElemSize = rt::utils::getTypeSize(mambaConfig.convStateType);
    size_t const recurrentBatchBytes = static_cast<size_t>(mambaConfig.recurrentStateNumHeads
                                           * mambaConfig.recurrentStateHeadDim * mambaConfig.recurrentStateSize)
        * recurrentElemSize;
    size_t const convBatchBytes = static_cast<size_t>(mambaConfig.convDim * mambaConfig.convKernel) * convElemSize;

    for (int32_t layer = 0; layer < mBaseEngineConfig.numLinearAttnLayers; ++layer)
    {
        rt::Tensor& recurrentLayer = mambaCache.getRecurrentState(layer);
        rt::Tensor& convLayer = mambaCache.getConvState(layer);

        auto* recurrentDst = static_cast<std::byte*>(recurrentLayer.rawPointer()) + batchIdx * recurrentBatchBytes;
        auto* convDst = static_cast<std::byte*>(convLayer.rawPointer()) + batchIdx * convBatchBytes;
        CUDA_CHECK(cudaMemsetAsync(recurrentDst, 0, recurrentBatchBytes, stream));
        CUDA_CHECK(cudaMemsetAsync(convDst, 0, convBatchBytes, stream));
    }
}

bool LLMInferenceSpecDecodeRuntime::setUpForPrefillExecution(SpecDecodeInferenceContext& context)
{
    NVTX_SCOPED_RANGE(nvtx_setup, "SETUP_PREFILL_EXECUTION", nvtx_colors::PALE_GREEN);

    if (mBaseEngineConfig.maxSupportedLoraRank > 0 && !mBaseEngineRunner->switchLoraWeights(context.loraWeightsName))
    {
        LOG_ERROR("Failed to switch LoRA weights to %s", context.loraWeightsName.c_str());
        return false;
    }

    int32_t const activeBatchSize = context.activeBatchSize;
    std::vector<std::vector<int32_t>> const& batchedInputIds = context.rawBatchedInputIds;
    rt::HybridCacheManager& baseCacheManager = mBaseEngineRunner->getCacheManager();

    // Record the length of the reused KVCache for each sequence using pre-allocated tensor.
    // Use activeBatchSize (actual request size)
    check::check(mHostReuseKVCacheLengths.reshape({activeBatchSize}), "Tensor reshape failed");
    int32_t* reuseKVCacheLengthsData = mHostReuseKVCacheLengths.dataPointer<int32_t>();

    // Initialize reuse lengths to 0 for all active sequences
    std::fill(reuseKVCacheLengthsData, reuseKVCacheLengthsData + activeBatchSize, 0);

    // Initialize tokenIds and effectivePrefillLengths for each sequence
    context.tokenIds.clear();
    context.tokenIds.resize(activeBatchSize);

    // Search if the system prompt has been cached. If there are cached system prompts, insert
    // the pre-computed KVCache and remove the cached portion from inputIds.
    // Directly populate context.tokenIds and context.effectivePrefillLengths (no padding)
    for (int32_t i = 0; i < activeBatchSize; ++i)
    {
        auto const& prompt = context.systemPrompts[i];
        auto const promptKey = keySystemPromptWithLoraWeights(prompt, context.loraWeightsName);
        if (mSystemPromptKVCacheBase.count(promptKey) > 0)
        {
            auto& precachedKVCacheBase = mSystemPromptKVCacheBase[promptKey];
            baseCacheManager.restoreKVCache(precachedKVCacheBase.kvCacheLayers, i, context.stream);

            if (mDraftEngineRunner != nullptr)
            {
                check::check(mSystemPromptKVCacheDraft.count(promptKey) > 0,
                    "System prompt cache inconsistency between base and draft model");
                auto& precachedKVCacheDraft = mSystemPromptKVCacheDraft[promptKey];
                mDraftEngineRunner->getCacheManager().restoreKVCache(
                    precachedKVCacheDraft.kvCacheLayers, i, context.stream);
            }

            // Restore recurrent states if applicable
            if (mBaseEngineConfig.numLinearAttnLayers > 0)
            {
                restoreRecurrentStates(i, precachedKVCacheBase, context.stream);
            }

            auto reuseLength = math::cast<size_t>(precachedKVCacheBase.kvCacheLayers[0].getShape()[2]);
            // If the system prompt is not well designed, the boundary of the input IDs could be misaligned.
            check::check(reuseLength > 0 && reuseLength < batchedInputIds[i].size(),
                "The reuse length shall be greater than 0 and less than the input length.");
            // Reuse N-1 tokens from the cached prefix so the Nth token is treated as real input in prefill;
            // this keeps the draft prefill boundary aligned with the true next-token position.
            auto const effectiveReuseLength = reuseLength - 1;
            reuseKVCacheLengthsData[i] = math::cast<int32_t>(effectiveReuseLength);

            // Directly assign to context.tokenIds (skip only the reused portion, keep the next token for normal flow)
            context.tokenIds[i].assign(batchedInputIds[i].begin() + effectiveReuseLength, batchedInputIds[i].end());
            context.effectivePrefillLengths[i] = math::cast<int32_t>(batchedInputIds[i].size() - effectiveReuseLength);

            bool const matchIds = std::equal(precachedKVCacheBase.tokenizedPrompt.begin(),
                precachedKVCacheBase.tokenizedPrompt.end(), batchedInputIds[i].begin());
            if (!matchIds)
            {
                LOG_WARNING(
                    "Though system prompt strings are matched, token_ids are not perfectly aligned. "
                    "This may generate incorrect results; please check your system prompt design.");
            }
        }
        else
        {
            // Directly assign to context.tokenIds (full input)
            context.tokenIds[i] = batchedInputIds[i];
            context.effectivePrefillLengths[i] = static_cast<int32_t>(batchedInputIds[i].size());
            reuseKVCacheLengthsData[i] = 0;

            // Zero recurrent states for batches without cache hit
            if (mBaseEngineConfig.numLinearAttnLayers > 0)
            {
                zeroRecurrentStates(i, context.stream);
            }
        }
    }

    // Validate max input length
    int32_t const maxInputLength
        = *std::max_element(context.effectivePrefillLengths.begin(), context.effectivePrefillLengths.end());
    if (maxInputLength > mBaseEngineConfig.maxSupportedInputLength)
    {
        LOG_ERROR("The max input length (%d) exceeds the max supported input length (%d) of the LLM Engine.",
            maxInputLength, mBaseEngineConfig.maxSupportedInputLength);
        return false;
    }

    baseCacheManager.resetForNewSequences(mHostReuseKVCacheLengths, context.stream);

    if (mDraftEngineRunner != nullptr)
    {
        mDraftEngineRunner->getCacheManager().resetForNewSequences(mHostReuseKVCacheLengths, context.stream);
    }

    return true;
}

bool LLMInferenceSpecDecodeRuntime::genAndSaveSystemPromptKVCache(
    SpecDecodeInferenceContext& context, int32_t genAndSaveBatchIdx)
{
    std::string const& loraWeightsName = context.loraWeightsName;
    // Check if cache already exists
    std::string const prompt = context.systemPrompts[genAndSaveBatchIdx];
    auto const promptKey = keySystemPromptWithLoraWeights(prompt, loraWeightsName);

    if (prompt.empty())
    {
        LOG_DEBUG("The systemPrompt is empty. Skip saving system prompt KVCache.");
        return true;
    }

    if (mSystemPromptKVCacheBase.find(promptKey) != mSystemPromptKVCacheBase.end()
        && (!mDraftEngineRunner || mSystemPromptKVCacheDraft.find(promptKey) != mSystemPromptKVCacheDraft.end()))
    {
        LOG_DEBUG("The system prompt KVCache already exists for the prompt: {%s}", prompt.c_str());
        return true;
    }

    auto tokenizedPrompt = mTokenizer->encode(prompt, true);
    if (tokenizedPrompt.empty())
    {
        LOG_ERROR("Failed to encode system prompt for KVCache generation.");
        return false;
    }
    int32_t const promptIdsLength = static_cast<int32_t>(tokenizedPrompt.size());

    if (promptIdsLength > mBaseEngineConfig.maxSupportedInputLength)
    {
        LOG_ERROR("System prompt length (%d) exceeds max supported input length (base=%d)", promptIdsLength,
            mBaseEngineConfig.maxSupportedInputLength);
        return false;
    }

    if (mDraftEngineConfig.has_value() && promptIdsLength > mDraftEngineConfig->maxSupportedInputLength)
    {
        LOG_ERROR("System prompt length (%d) exceeds max supported input length (draft=%d)", promptIdsLength,
            mDraftEngineConfig->maxSupportedInputLength);
        return false;
    }

    // Create a temporary single-batch context for system prompt KVCache generation
    // Reuse the existing prefill functions which will use runtime member tensors (mIdsInput, mLogitsOutput, etc.)
    SpecDecodeInferenceContext tempContext;
    // Generate with batch size 1 and generate length 1 (prefill only).
    tempContext.initialize(1, 1, context.visualEmbeddings, context.deepstackFeatures, loraWeightsName, context.stream);
    tempContext.systemPrompts[0] = prompt;
    tempContext.rawBatchedInputIds.push_back(tokenizedPrompt);
    tempContext.tokenIds[0] = tokenizedPrompt;

    // Setup for prefill execution: handles KV cache reset and applies any reused system prompt cache
    if (!setUpForPrefillExecution(tempContext))
    {
        LOG_ERROR("Prefill execution setup failed for system prompt KVCache generation.");
        return false;
    }

    // Run base model prefill (reuses mIdsInput, mLogitsOutput, mBaseHiddenStatesOutput)
    bool prefillStatus = runBaseModelPrefill(tempContext);
    if (!prefillStatus)
    {
        LOG_ERROR("Failed to execute base model prefill for system prompt KVCache generation.");
        return false;
    }

    // The token sampled during this system prompt KVCache prefill does not count as a generated token.
    tempContext.currentGenerateLengths[0] -= 1;

    // Run draft model prefill if draft model is present
    if (mDraftEngineRunner != nullptr)
    {
        bool draftPrefillStatus = runDraftModelPrefill(tempContext);
        if (!draftPrefillStatus)
        {
            LOG_ERROR("Failed to execute draft model prefill for system prompt KVCache generation.");
            return false;
        }
    }
    CUDA_CHECK(cudaStreamSynchronize(context.stream));

    // Copy out the KVCache content from the prefill step
    auto& baseCacheManager = mBaseEngineRunner->getCacheManager();

    SystemPromptKVCache savedKVCacheBase;
    savedKVCacheBase.systemPrompt = prompt;
    savedKVCacheBase.tokenizedPrompt = tokenizedPrompt;

    // We only process one sequence at a time
    constexpr int32_t CACHE_BATCH_IDX{0};
    savedKVCacheBase.kvCacheLayers = baseCacheManager.captureKVCache(CACHE_BATCH_IDX, promptIdsLength, context.stream);

    // Save recurrent and conv states for hybrid layers
    if (mBaseEngineConfig.numLinearAttnLayers > 0)
    {
        savedKVCacheBase.recurrentStateContents
            = baseCacheManager.captureRecurrentStates(CACHE_BATCH_IDX, context.stream);
        savedKVCacheBase.convStateContents = baseCacheManager.captureConvStates(CACHE_BATCH_IDX, context.stream);
    }

    mSystemPromptKVCacheBase.insert({promptKey, std::move(savedKVCacheBase)});

    if (mDraftEngineRunner != nullptr)
    {
        auto& draftCacheManager = mDraftEngineRunner->getCacheManager();

        SystemPromptKVCache savedKVCacheDraft;
        savedKVCacheDraft.systemPrompt = prompt;
        savedKVCacheDraft.tokenizedPrompt = tokenizedPrompt;
        savedKVCacheDraft.kvCacheLayers
            = draftCacheManager.captureKVCache(CACHE_BATCH_IDX, promptIdsLength, context.stream);
        mSystemPromptKVCacheDraft.insert({promptKey, std::move(savedKVCacheDraft)});
    }

    CUDA_CHECK(cudaStreamSynchronize(context.stream));
    LOG_DEBUG("System prompt KVCache saved for batch %d: {%s}", genAndSaveBatchIdx, prompt.c_str());

    return true;
}

bool LLMInferenceSpecDecodeRuntime::genAndSaveSystemPromptKVCache(
    std::string const& prompt, std::string const& loraWeightsName, cudaStream_t stream)
{
    if (prompt.empty())
    {
        LOG_DEBUG("The systemPrompt is empty. Skip saving system prompt KVCache.");
        return true;
    }
    auto const promptKey = keySystemPromptWithLoraWeights(prompt, loraWeightsName);
    if (mSystemPromptKVCacheBase.find(promptKey) != mSystemPromptKVCacheBase.end())
    {
        LOG_DEBUG("The system prompt KVCache already exists for the prompt: {%s}", prompt.c_str());
        return true;
    }
    SpecDecodeInferenceContext tempContext;
    tempContext.initialize(1, 1, std::nullopt, rt::OptionalInputTensors{}, loraWeightsName, stream);
    tempContext.systemPrompts[0] = prompt;
    auto tokenizedPrompt = mTokenizer->encode(prompt, true);
    if (tokenizedPrompt.empty())
    {
        LOG_ERROR("Failed to encode system prompt for KVCache generation.");
        return false;
    }
    tempContext.rawBatchedInputIds.push_back(tokenizedPrompt);
    tempContext.tokenIds[0] = tokenizedPrompt;
    return genAndSaveSystemPromptKVCache(tempContext, 0);
}

bool LLMInferenceSpecDecodeRuntime::performBatchEvict(SpecDecodeInferenceContext& context)
{
    // Check if any batch has finished
    bool hasFinishedBatch = false;
    for (int32_t i = 0; i < context.activeBatchSize; ++i)
    {
        if (context.finishedStates[i])
        {
            hasFinishedBatch = true;
            break;
        }
    }

    if (!hasFinishedBatch)
    {
        return true;
    }

    int32_t const oldActiveBatch = context.activeBatchSize;

    // Build batch mapping
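    // batchMapping[oldIdx] holds the compacted index for surviving sequences and a negative value for evicted ones.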
    std::vector<int32_t> batchMapping = buildBatchMapping(context.finishedStates);

    // Calculate new active batch size
    int32_t newActiveBatch = 0;
    for (auto newIdx : batchMapping)
    {
        if (newIdx >= 0)
        {
            newActiveBatch = std::max(newActiveBatch, newIdx + 1);
        }
    }

    // Log eviction details
    std::vector<int32_t> evictedIndices;
    for (int32_t i = 0; i < oldActiveBatch; ++i)
    {
        if (batchMapping[i] < 0)
        {
            evictedIndices.push_back(i);
        }
    }
    LOG_DEBUG("Batch eviction: %d active batches to %d remaining (evicted %d batch(es): indices [%s])", oldActiveBatch,
        newActiveBatch, static_cast<int32_t>(evictedIndices.size()),
        [&evictedIndices]() {
            std::string result;
            for (size_t i = 0; i < evictedIndices.size(); ++i)
            {
                if (i > 0)
                {
                    result += ", ";
                }
                result += std::to_string(evictedIndices[i]);
            }
            return result;
        }()
            .c_str());

    // Upload batch mapping to GPU
    check::check(mDeviceBatchMapping.reshape({oldActiveBatch}), "Tensor reshape failed");
    CUDA_CHECK(cudaMemcpyAsync(mDeviceBatchMapping.rawPointer(), batchMapping.data(), oldActiveBatch * sizeof(int32_t),
        cudaMemcpyHostToDevice, context.stream));

    // Compact Base KV Cache
    mBaseEngineRunner->getCacheManager().compactBatch(
        mDeviceBatchMapping, oldActiveBatch, newActiveBatch, context.stream);
    mBaseEngineRunner->getCacheManager().setActiveBatchSize(newActiveBatch);

    // Compact Draft KV Cache (only when draft model is present)
    if (mDraftEngineRunner != nullptr)
    {
        mDraftEngineRunner->getCacheManager().compactBatch(
            mDeviceBatchMapping, oldActiveBatch, newActiveBatch, context.stream);
        mDraftEngineRunner->getCacheManager().setActiveBatchSize(newActiveBatch);

        // Compact Draft Model's RoPE CosSin Cache if it's per-batch (MRope for multimodal)
        rt::Tensor& draftRopeCache = mDraftEngineRunner->getRopeCosSinCacheTensor();
        if (draftRopeCache.getShape().getNumDims() == 3 && draftRopeCache.getShape()[0] == oldActiveBatch
            && newActiveBatch > 0)
        {
            kernel::compactTensorBatch(
                draftRopeCache, mDeviceBatchMapping, draftRopeCache, oldActiveBatch, newActiveBatch, context.stream);
            auto const seqLen = static_cast<int32_t>(draftRopeCache.getShape()[1]);
            auto const rotaryDim = static_cast<int32_t>(draftRopeCache.getShape()[2]);
            check::check(draftRopeCache.reshape({newActiveBatch, seqLen, rotaryDim}), "Tensor reshape failed");
        }
    }

    // Compact Base Model's RoPE CosSin Cache if it's per-batch (MRope for multimodal)
    rt::Tensor& baseRopeCache = mBaseEngineRunner->getRopeCosSinCacheTensor();
    if (baseRopeCache.getShape().getNumDims() == 3 && baseRopeCache.getShape()[0] == oldActiveBatch
        && newActiveBatch > 0)
    {
        kernel::compactTensorBatch(
            baseRopeCache, mDeviceBatchMapping, baseRopeCache, oldActiveBatch, newActiveBatch, context.stream);
        auto const seqLen = static_cast<int32_t>(baseRopeCache.getShape()[1]);
        auto const rotaryDim = static_cast<int32_t>(baseRopeCache.getShape()[2]);
        check::check(baseRopeCache.reshape({newActiveBatch, seqLen, rotaryDim}), "Tensor reshape failed");
    }

    // Compact cross-round GPU tensors that are read (not just written) in the next round

    // 1. mBaseHiddenStatesOutput: read by runDraftModelAcceptToken in next round
    //    Shape: [activeBatchSize, maxAcceptDepth, baseHiddenDim]
    if (mBaseHiddenStatesOutput.getShape().getNumDims() == 3 && mBaseHiddenStatesOutput.getShape()[0] == oldActiveBatch
        && newActiveBatch > 0)
    {
        kernel::compactTensorBatch(mBaseHiddenStatesOutput, mDeviceBatchMapping, mBaseHiddenStatesOutput,
            oldActiveBatch, newActiveBatch, context.stream);
        auto const dim1 = static_cast<int32_t>(mBaseHiddenStatesOutput.getShape()[1]);
        auto const dim2 = static_cast<int32_t>(mBaseHiddenStatesOutput.getShape()[2]);
        check::check(mBaseHiddenStatesOutput.reshape({newActiveBatch, dim1, dim2}), "Tensor reshape failed");
    }

    // 2. mAcceptedTokenIds: read by runDraftModelAcceptToken to prepare input IDs (draft-only)
    //    Shape: [activeBatchSize, maxAcceptDepth]
    if (mDraftEngineRunner != nullptr)
    {
        if (mAcceptedTokenIds.getShape()[0] == oldActiveBatch && newActiveBatch > 0)
        {
            kernel::compactTensorBatch(mAcceptedTokenIds, mDeviceBatchMapping, mAcceptedTokenIds, oldActiveBatch,
                newActiveBatch, context.stream);
            auto const maxAcceptDepth = static_cast<int32_t>(mAcceptedTokenIds.getShape()[1]);
            check::check(mAcceptedTokenIds.reshape({newActiveBatch, maxAcceptDepth}), "Tensor reshape failed");
        }

        // 3. mAcceptLength: read by runDraftModelAcceptToken to set per-batch accept counts (draft-only)
        //    Shape: [activeBatchSize]
        if (mAcceptLength.getShape()[0] == oldActiveBatch && newActiveBatch > 0)
        {
            kernel::compactTensorBatch(
                mAcceptLength, mDeviceBatchMapping, mAcceptLength, oldActiveBatch, newActiveBatch, context.stream);
            check::check(mAcceptLength.reshape({newActiveBatch}), "Tensor reshape failed");
        }
    }

    // Compact CPU context
    CUDA_CHECK(cudaStreamSynchronize(context.stream));

    // Save evicted batches' results before compacting (using original batch index)
    for (size_t i = 0; i < batchMapping.size(); ++i)
    {
        if (batchMapping[i] < 0 && context.finishedStates[i])
        {
            // This batch is evicted and finished, save its results with original index
            int32_t originalIdx = context.batchIndexMapping[i];

            // Create and populate BatchResult with all related data
            BatchResult result;
            result.tokenIds = std::move(context.tokenIds[i]);
            result.generateLength = context.currentGenerateLengths[i];
            result.actualIterations = context.generationRound;
            result.rawBatchedInputIds = std::move(context.rawBatchedInputIds[i]);
            result.effectivePrefillLength = context.effectivePrefillLengths[i];

            context.completedBatches[originalIdx] = std::move(result);
        }
    }

    rt::compactVector(batchMapping, context.finishedStates);
    rt::compactVector(batchMapping, context.currentGenerateLengths);
    rt::compactVector(batchMapping, context.tokenIds);
    rt::compactVector(batchMapping, context.systemPrompts);
    rt::compactVector(batchMapping, context.rawBatchedInputIds);
    rt::compactVector(batchMapping, context.effectivePrefillLengths);
    rt::compactVector(batchMapping, context.batchIndexMapping);
    rt::compactVector(batchMapping, context.slotStreams);

    // Update active batch size
    context.activeBatchSize = newActiveBatch;

    return true;
}

} // namespace rt
} // namespace trt_edgellm
