/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "multimodalRunner.h"
#include <cstdint>
#include <cuda_fp16.h>
#include <string>
#include <tuple>
#include <vector>

namespace trt_edgellm
{
namespace rt
{

//! \brief Configuration for the Qwen-VL vision encoder
//!
//! Every member is value-initialized to zero (or empty) in this declaration. NOTE(review): the real values are
//! presumably populated from the engine directory's JSON configuration by QwenViTRunner::validateAndFillConfig --
//! confirm against the implementation before relying on any of these defaults.
struct QwenViTConfig
{
    int64_t maxHW{0};         //!< Maximum image area (height * width)
    int64_t minHW{0};         //!< Minimum image area (height * width)
    int64_t inputDim{0};      //!< Input dimension of the vision encoder
    int64_t vitPosEmbDim{0};  //!< Vision transformer position embedding dimension
    int64_t outHiddenSize{0}; //!< Output hidden dimension size of the vision encoder

    int32_t vocabSize{0};              //!< Vocabulary size
    int32_t visionStartTokenId{0};     //!< Token ID marking the start of a vision segment
    int32_t visionEndTokenId{0};       //!< Token ID marking the end of a vision segment
    int32_t imageTokenId{0};           //!< Token ID of the image placeholder
    int32_t videoTokenId{0};           //!< Token ID of the video placeholder
    float mropeTheta{0};               //!< Theta parameter of the multi-dimensional RoPE (MRoPE)
    int64_t patchSize{0};              //!< Patch size in pixels
    int64_t temporalPatchSize{0};      //!< Temporal patch size for video
    int64_t mergeSize{0};              //!< Merge size applied to patches
    int64_t windowSize{0};             //!< Window attention size (Qwen2.5-VL only)
    int64_t numGridPerSide{0};         //!< Number of grids per side for the fast position embedding (Qwen3-VL only)
    int64_t numDeepstackFeatures{0};   //!< Number of deepstack features (Qwen3-VL only)
    int32_t mropeSectionH{0};          //!< MRoPE section: number of frequency pairs assigned to height
    int32_t mropeSectionW{0};          //!< MRoPE section: number of frequency pairs assigned to width
    int64_t minImageTokensPerImage{0}; //!< Minimum number of image tokens produced per image; used for resizing
    int64_t maxImageTokensPerImage{0}; //!< Maximum number of image tokens produced per image; used for resizing
    int64_t maxNumImages{0};           //!< Maximum number of images per request; used for buffer pre-allocation

    std::vector<float> imageMean{}; //!< Per-channel image normalization mean values (RGB)
    std::vector<float> imageStd{};  //!< Per-channel image normalization standard deviation values (RGB)
};

//! \brief Runner for the Qwen-VL vision encoder
//!
//! This class handles the preprocessing and inference of the Qwen-VL vision encoder. Preprocessing turns the images
//! of a request into encoder input patches, tokenizes the text (taking per-image token lengths into account), and
//! fills the multi-dimensional RoPE (MRoPE) cos/sin cache passed in by the caller. Inference runs the vision engine
//! loaded from the engine directory. Model-specific state is kept in dedicated members: window-attention indices for
//! Qwen2.5-VL, and fast position embeddings plus deepstack features for Qwen3-VL.
class QwenViTRunner : public MultimodalRunner
{
public:
    //! \brief Constructor for QwenViTRunner
    //! \param[in] engineDir Directory containing the TensorRT engine files
    //! \param[in] llmMaxBatchSize Maximum batch size from LLM engine
    //! \param[in] llmMaxSequenceLength Maximum sequence length from LLM engine
    //! \param[in] stream CUDA stream for execution
    //! \throws std::runtime_error if engine directory does not contain engine files, or if buffer allocation fails
    //! \throws json::type_error if JSON configuration contains unexpected datatypes
    QwenViTRunner(
        std::string const& engineDir, int32_t llmMaxBatchSize, int32_t llmMaxSequenceLength, cudaStream_t stream);

    //! \brief Defaulted, non-throwing destructor
    ~QwenViTRunner() noexcept = default;

    //! \brief Preprocess multimodal input including images and text
    //! \param[in] request LLM generation request containing images and text
    //! \param[in,out] batchedInputIds Batched input token IDs after preprocessing
    //! \param[in] tokenizer Tokenizer for text processing
    //! \param[in,out] ropeRotaryCosSinDevice RoPE rotary position encoding cache
    //! \param[in] stream CUDA stream for execution
    //! \param[in] imageOnly Defaults to false. NOTE(review): semantics are defined by the implementation / the
    //!            MultimodalRunner base; presumably restricts the work to image preprocessing -- confirm.
    //! \return True if preprocessing succeeded, false otherwise
    //! \throws std::runtime_error if sequence length is invalid, or a CUDA error occurs
    bool preprocess(rt::LLMGenerationRequest const& request, std::vector<std::vector<int32_t>>& batchedInputIds,
        tokenizer::Tokenizer const* tokenizer, rt::Tensor& ropeRotaryCosSinDevice, cudaStream_t stream,
        bool imageOnly = false) override;

    //! \brief Encode the system prompt and generate ND-RoPE parameters for it, for KVCache saving
    //! \param[in] systemPrompt System prompt string
    //! \param[in] tokenizer Tokenizer for text processing
    //! \param[in,out] ropeRotaryCosSinDevice RoPE rotary position encoding cache
    //! \param[in] stream CUDA stream for execution
    //! \return True if preprocessing succeeded, false otherwise
    bool preprocessSystemPrompt(std::string const& systemPrompt, tokenizer::Tokenizer const* tokenizer,
        rt::Tensor& ropeRotaryCosSinDevice, cudaStream_t stream) override;

    //! \brief Run inference on the vision encoder
    //! \param[in] stream CUDA stream for execution
    //! \return True if inference succeeded, false otherwise (declared noexcept: failures are reported via the
    //!         return value, never by throwing)
    bool infer(cudaStream_t stream) noexcept override;

    //! \brief Validate and load configuration from JSON file
    //! \param[in] engineDir Path to engine directory
    //! \return True if configuration is valid and loaded successfully, false otherwise
    //! \throws json::type_error if JSON configuration contains unexpected datatypes
    bool validateAndFillConfig(std::string const& engineDir) override;

    //! \brief Allocate buffers for inference
    //! \param[in] stream CUDA stream for execution
    //! \return True if allocation succeeded, false otherwise
    //! \throws std::runtime_error if a CUDA operation fails
    bool allocateBuffer(cudaStream_t stream) override;

    //! \brief Get deepstack features for Qwen3-VL
    //! \return Optional input tensors vector containing deepstack features
    rt::OptionalInputTensors getDeepstackFeatures() override;

private:
    //! \brief Calculate resized image dimensions based on dynamic resolution constraints
    //! \param[in] height Input image height
    //! \param[in] width Input image width
    //! \param[in] maxRatio Maximum allowed aspect ratio (default: 200)
    //! \return Tuple of (resized_height, resized_width)
    //! \throws std::runtime_error if aspect ratio is invalid
    std::tuple<int64_t, int64_t> getResizedImageSize(
        int64_t const height, int64_t const width, int64_t const maxRatio = 200);

    //! \brief Preprocess the text portion of the request
    //! \param[in] request LLM generation request
    //! \param[out] batchInputIds Batch of input token IDs
    //! \param[in] numImages Number of images per request
    //! \param[in] imageTokenLengths Token lengths for each image
    //! \param[in] tokenizer Tokenizer for text processing
    //! \throws std::runtime_error if the request size is incorrect
    void textPreprocess(rt::LLMGenerationRequest const& request, std::vector<std::vector<int32_t>>& batchInputIds,
        std::vector<int64_t> const& numImages, std::vector<int64_t> const& imageTokenLengths,
        trt_edgellm::tokenizer::Tokenizer const* tokenizer);

    //! \brief Compute window indices for window attention (Qwen2.5-VL)
    //! \param[in] imageGridTHWs Image grid dimensions (Temporal, Height, Width)
    //! \param[in] curHW Current height * width
    //! \param[in] stream CUDA stream for execution
    //! \throws std::runtime_error if image dimensions are invalid
    void getWindowIndex(
        std::vector<std::vector<int64_t>> const& imageGridTHWs, int64_t const curHW, cudaStream_t stream);

    //! \brief Format and process a single image into patches
    //! \param[in] image Input image data
    //! \param[out] imageGridTHWs Image grid dimensions for each image
    //! \param[out] imageTokenLengths Token lengths for each image
    //! \param[in,out] cuSeqlensData Pointer to cumulative sequence lengths data
    //! \param[in,out] cuSeqlensSize Reference to current size of cumulative sequence lengths
    //! \param[in,out] maxSeqLen Reference to current maximum sequence length in this request
    //! \param[in] stream CUDA stream for execution
    //! \throws std::runtime_error if image dimensions are incompatible with patch size, or sequence length is out of
    //! range
    //! \throws std::runtime_error if a CUDA error occurs
    void formatPatch(rt::imageUtils::ImageData const& image, std::vector<std::vector<int64_t>>& imageGridTHWs,
        std::vector<int64_t>& imageTokenLengths, int32_t* cuSeqlensData, int64_t& cuSeqlensSize, int64_t& maxSeqLen,
        cudaStream_t stream);

    //! \brief Compute multi-dimensional RoPE position indices
    //! \param[in] batchInputIds Batch of input token IDs
    //! \param[in] imageGridTHWs Image grid dimensions (Temporal, Height, Width)
    //! \note Declared noexcept and returns nothing. NOTE(review): results are presumably written to
    //!       mMropePositionIdsHost -- confirm against the implementation.
    void getMRopePositionIds(std::vector<std::vector<int32_t>> const& batchInputIds,
        std::vector<std::vector<int64_t>> const& imageGridTHWs) noexcept;

    //! \brief Generate multi-dimensional RoPE parameters
    //! \param[in] batchInputIds Batch of input token IDs
    //! \param[in] imageGridTHWs Image grid dimensions (Temporal, Height, Width)
    //! \param[in,out] ropeRotaryCosSinDevice RoPE rotary position encoding cache
    //! \param[in] stream CUDA stream for execution
    //! \throws std::runtime_error if shape validation fails, or a CUDA operation fails
    void generateMropeParams(std::vector<std::vector<int32_t>> const& batchInputIds,
        std::vector<std::vector<int64_t>> const& imageGridTHWs, rt::Tensor& ropeRotaryCosSinDevice,
        cudaStream_t stream);

    //! \brief Preprocess all images in the request
    //! \param[in] request LLM generation request containing images
    //! \param[out] imageGridTHWs Image grid dimensions for each image
    //! \param[out] imageTokenLengths Token lengths for each image
    //! \param[out] numImages Number of images per request
    //! \param[in] doResize Whether to resize images before patching
    //! \param[in] stream CUDA stream for execution
    //! \throws std::runtime_error if aspect ratio is invalid
    //! \throws std::runtime_error if image dimensions are incompatible with patch size, or sequence length is out of
    //! range
    //! \throws std::runtime_error if a CUDA error occurs
    void imagePreprocess(rt::LLMGenerationRequest const& request, std::vector<std::vector<int64_t>>& imageGridTHWs,
        std::vector<int64_t>& imageTokenLengths, std::vector<int64_t>& numImages, bool doResize, cudaStream_t stream);

    // NOTE: data-member declaration order determines construction order; keep any constructor init list in sync.
    QwenViTConfig mConfig{};                       //!< Qwen-VL configuration
    rt::Tensor mVitInput{};                        //!< Vision encoder input tensor
    rt::Tensor mRotaryPosEmb{};                    //!< Rotary position embeddings tensor (multi-dimensional RoPE)
    rt::Tensor mCuSeqlens{};                       //!< Cumulative sequence lengths device tensor
    rt::Tensor mCuSeqlensHost{};                   //!< Cumulative sequence lengths host tensor
    rt::Tensor mMaxSeqLenCarrier{};                //!< Shape-only input carrying max sequence length for FMHA launch
    rt::Tensor mImageMean{};                       //!< Image normalization mean tensor
    rt::Tensor mImageStd{};                        //!< Image normalization standard deviation tensor
    rt::Tensor mImageDevice{};                     //!< Temporary image buffer for preprocessing
    rt::Tensor mNormalizedImageDevice{};           //!< Temporary normalized image buffer for preprocessing
    rt::imageUtils::ImageData mResizedImageHost{}; //!< Pre-allocated host buffer for image resizing
    rt::Tensor mMropePositionIdsHost{};            //!< MRoPE position IDs host tensor
    rt::Tensor mMropePositionIdsDevice{};          //!< MRoPE position IDs device tensor
    // Qwen2.5-VL window attention state
    rt::Tensor mCuWindowSeqlens{};          //!< Cumulative window sequence lengths device tensor
    rt::Tensor mCuWindowSeqlensHost{};      //!< Cumulative window sequence lengths host tensor
    rt::Tensor mWindowIndexHost{};          //!< Window index host tensor for window attention
    rt::Tensor mWindowIndexDevice{};        //!< Window index device tensor for window attention
    rt::Tensor mReverseWindowIndexHost{};   //!< Reverse window index host tensor
    rt::Tensor mReverseWindowIndexDevice{}; //!< Reverse window index device tensor
    // Qwen3-VL fast position embedding and deepstack state
    rt::Tensor mFastPosEmbIdx{};                  //!< Fast position embeddings index tensor
    rt::Tensor mFastPosEmbWeight{};               //!< Fast position embeddings weight tensor
    std::vector<rt::Tensor> mDeepstackFeatures{}; //!< Deepstack features tensors

    int32_t mLLMMaxBatchSize{0};      //!< Maximum batch size from LLM engine
    int32_t mLLMMaxSequenceLength{0}; //!< Maximum sequence length from LLM engine

    std::vector<std::vector<int64_t>> mLastImageGridTHWs; //!< Image grid THWs presumably from the previous call;
                                                          //!< used to determine whether RoPE can be reused.
};

} // namespace rt
} // namespace trt_edgellm
