Source code for tensorrt_edgellm.onnx_export.visual_export

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visual model export functionality for TensorRT Edge-LLM.

This module provides functions to export visual components of multimodal models
to ONNX format with optional quantization support.
"""

import json
import os
from typing import Optional

import torch

from tensorrt_edgellm.quantization.visual_quantization import quantize_visual
# Import visual model wrappers
from tensorrt_edgellm.visual_models.internvl3_model import (
    InternVLVisionModel, export_internvl3_visual)
from tensorrt_edgellm.visual_models.phi4mm_model import (Phi4MMVisionModel,
                                                         export_phi4mm_visual)
from tensorrt_edgellm.visual_models.qwen2_5_vl_model import (
    Qwen2_5_VisionTransformerPretrainedModelPatch, export_qwen2_5_vl_visual)
from tensorrt_edgellm.visual_models.qwen2_vl_model import (
    Qwen2VisionTransformerPretrainedModelPatch, export_qwen2_vl_visual)
from tensorrt_edgellm.visual_models.qwen3_vl_model import (
    Qwen3VLVisionModelPatch, export_qwen3_vl_visual)

from ..llm_models.model_utils import load_hf_model
from .config_export import export_vision_config


def _export_qwen_visual(model, model_type: str, model_dir: str,
                        output_dir: str, torch_dtype: torch.dtype, device: str,
                        quantization: Optional[str], processor,
                        dataset_dir: str) -> None:
    """
    Export visual models in the Qwen family (qwen2/qwen2.5/qwen3/qwen3-omni).
    """
    visual_model = model.visual if model_type != 'qwen3_omni' else model.thinker.visual

    # Quantize the original visual model first
    if quantization == "fp8":
        # The Qwen2.5-VL 3B ViT has an FP16 overflow issue that occurs only in /blocks.31/mlp/down_proj.
        # Workaround: cast /blocks.31/mlp/down_proj to FP32 to avoid the overflow during quantization.
        if model_type == 'qwen2_5_vl' and visual_model.config.out_hidden_size == 2048:
            from tensorrt_edgellm.visual_models.qwen2_5_vl_model import (
                Qwen2_5_VLMLPPatch, Qwen2_5_VLPatchMergerWAR)
            last_block = visual_model.blocks[-1]
            last_block.mlp = Qwen2_5_VLMLPPatch(visual_model.config,
                                                last_block.mlp)
            visual_model.merger = Qwen2_5_VLPatchMergerWAR(
                visual_model.config, visual_model.merger)

        visual_model = quantize_visual(visual_model, quantization, processor,
                                       dataset_dir)

    print(f"Exporting {model_type} visual model from {model_dir}")

    if model_type == 'qwen2_vl':
        wrapped_model = Qwen2VisionTransformerPretrainedModelPatch(
            visual_model)
        wrapped_model.eval().to(device)
        export_qwen2_vl_visual(wrapped_model, output_dir, torch_dtype)
    elif model_type == 'qwen2_5_vl':
        wrapped_model = Qwen2_5_VisionTransformerPretrainedModelPatch(
            visual_model)
        wrapped_model.eval().to(device)
        export_qwen2_5_vl_visual(wrapped_model, output_dir, torch_dtype)
    else:
        # qwen3_vl and qwen3_omni share the same visual wrapper/export path
        wrapped_model = Qwen3VLVisionModelPatch(visual_model)
        wrapped_model.eval().to(device)
        export_qwen3_vl_visual(wrapped_model, output_dir, torch_dtype)


def visual_export(model_dir: str,
                  output_dir: str,
                  dtype: str,
                  quantization: Optional[str],
                  dataset_dir: Optional[str] = "lmms-lab/MMMU",
                  device: str = "cuda") -> str:
    """
    Export a visual model using the appropriate wrapper based on model architecture.

    This function loads a multimodal model, extracts its visual component,
    wraps it in the appropriate model wrapper, applies quantization if
    requested, and exports it to ONNX format.

    Args:
        model_dir: Directory containing the torch model
        output_dir: Directory to save the exported ONNX model
        dtype: Data type for export (currently only "fp16" supported)
        quantization: Quantization type ("fp8" or None)
        dataset_dir: Dataset (local path or Hugging Face dataset ID) used for
            calibration when quantization is enabled (default: "lmms-lab/MMMU")
        device: Device to load the model on (default: "cuda",
            options: cpu, cuda, cuda:0, cuda:1, etc.)

    Returns:
        str: Path to the output directory where the exported model is saved

    Raises:
        ValueError: If an unsupported dtype or model type is detected, or the
            model cannot be loaded
        AssertionError: If an unsupported quantization mode is provided
    """
    # Convert dtype string to torch dtype
    if dtype == "fp16":
        torch_dtype = torch.float16
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    # Validate quantization mode
    assert quantization in [
        "fp8", None
    ], f"Only fp8 or None is supported for quantization. You passed: {quantization}"

    # Load the model and processor
    try:
        model, _, processor = load_hf_model(model_dir, dtype, device)
    except Exception as e:
        raise ValueError(f"Could not load model from {model_dir}. Error: {e}")

    # Determine the model architecture type
    model_type = model.config.model_type

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    if model_type in ['qwen2_vl', 'qwen2_5_vl', 'qwen3_vl', 'qwen3_omni']:
        _export_qwen_visual(model,
                            model_type,
                            model_dir,
                            output_dir,
                            torch_dtype=torch_dtype,
                            device=device,
                            quantization=quantization,
                            processor=processor,
                            dataset_dir=dataset_dir)
    elif model_type == 'internvl':
        print(f"Exporting InternVL3 visual model from {model_dir}")

        # Create InternVL3 wrapper model
        wrapped_model = InternVLVisionModel(model)
        wrapped_model.eval().to(device)

        # Apply quantization to wrapped model if requested
        if quantization == "fp8":
            wrapped_model = quantize_visual(wrapped_model, quantization,
                                            processor.image_processor,
                                            dataset_dir)

        # Export using the wrapper's export function
        export_internvl3_visual(wrapped_model, output_dir, torch_dtype)
    elif model_type == 'phi4mm':
        print(f"Exporting Phi4MM visual model from {model_dir}")

        # Create Phi4MM wrapper model
        wrapped_model = Phi4MMVisionModel(model)
        wrapped_model.eval().to(device)

        # Apply quantization to wrapped model if requested
        if quantization == "fp8":
            wrapped_model = quantize_visual(wrapped_model, quantization,
                                            processor.image_processor,
                                            dataset_dir)

        export_phi4mm_visual(wrapped_model, output_dir, torch_dtype)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Export model configuration to JSON
    config_dict = export_vision_config(model.config.thinker_config if model_type
                                       == 'qwen3_omni' else model.config)
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    # Export processor configuration to JSON if it exists
    if processor is not None:
        # Phi4MMProcessor may not define audio_tokenizer, but transformers'
        # save_pretrained expects the attribute.
        if not hasattr(processor, "audio_tokenizer"):
            processor.audio_tokenizer = None
        processor.save_pretrained(output_dir)

    print(
        f"Visual export completed for {model_type} with dtype={dtype}, quantization={quantization}, device={device}"
    )
    print(f"Exported to: {output_dir}")

    return output_dir
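
# Example usage (illustrative sketch): exporting a visual encoder to ONNX with
# optional FP8 quantization via the public entry point defined above. The model
# path and output directory below are placeholders, not values shipped with
# this module.
#
#     from tensorrt_edgellm.onnx_export.visual_export import visual_export
#
#     visual_export(
#         model_dir="/path/to/Qwen2.5-VL-3B-Instruct",  # placeholder checkpoint dir
#         output_dir="./qwen2_5_vl_visual_onnx",        # placeholder output dir
#         dtype="fp16",                                 # only fp16 is supported
#         quantization="fp8",                           # or None to skip quantization
#     )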