Source code for tensorrt_llm.models.multimodal_encoders.model

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict
from typing import Optional

import safetensors

from ..._utils import numpy_to_torch
from ...functional import ACT2FN, Tensor, concat, shape, slice
from ...layers import Linear
from ...logger import logger
from ...mapping import Mapping
from ...models import CLIPVisionTransformer
from ...module import Module
from ...parameter import Parameter
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import PretrainedModel, QuantConfig
from .config import LlavaNextVisionConfig


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/llava_next/modeling_llava_next.py#L149
class LlavaNextMultiModalProjector(Module):
    """Two-layer projector applied to vision-tower features.

    Applies ``linear_1`` (``config.hidden_size`` -> ``config.text_hidden_size``),
    the activation selected by ``config.projector_hidden_act``, and then
    ``linear_2`` (``config.text_hidden_size`` -> ``config.text_hidden_size``).
    """

    def __init__(self, config: LlavaNextVisionConfig):
        super().__init__()

        vision_width = config.hidden_size
        text_width = config.text_hidden_size
        layer_dtype = config.dtype

        # Attribute names and their assignment order are kept as-is: they
        # determine the parameter names seen by the weight loader.
        self.linear_1 = Linear(vision_width, text_width, dtype=layer_dtype)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = Linear(text_width, text_width, dtype=layer_dtype)

    def forward(self, image_features):
        """Project ``image_features`` through linear_1 -> act -> linear_2."""
        projected = self.linear_1(image_features)
        activated = self.act(projected)
        return self.linear_2(activated)


class LlavaNextVisionWrapper(PretrainedModel):
    """Vision side of LLaVA-NeXT: a vision tower followed by the projector.

    The wrapper owns three pieces of state:

    * ``vision_tower`` - a ``CLIPVisionTransformer`` when
      ``config.vision_model_type == "clip_vision_model"``, otherwise ``None``.
    * ``multi_modal_projector`` - a ``LlavaNextMultiModalProjector``.
    * ``image_newline`` - a parameter of shape ``(config.text_hidden_size,)``
      that is stored with the checkpoint but not used in ``forward``.
    """

    def __init__(self, config: LlavaNextVisionConfig):
        """Build the vision tower, the projector and the newline parameter.

        Args:
            config: Vision configuration. Only the ``"clip_vision_model"``
                value of ``config.vision_model_type`` creates a vision tower;
                any other value is logged as an error and leaves
                ``self.vision_tower`` set to ``None``.
        """
        super().__init__(config)
        self.vision_tower = None
        self.config = config
        if config.vision_model_type == "clip_vision_model":
            self.vision_tower = CLIPVisionTransformer(
                image_size=config.image_size,
                num_channels=config.num_channels,
                patch_size=config.patch_size,
                hidden_size=config.hidden_size,
                num_attention_heads=config.num_attention_heads,
                max_position_embeddings=config.max_position_embeddings,
                norm_epsilon=config.norm_epsilon,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                num_hidden_layers=config.num_hidden_layers,
                require_ln_f=False,
                mapping=config.mapping,
                dtype=config.dtype)
        else:
            logger.error(
                "Currently TRT-LLM only supports CLIP vision transformer.")
        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
        self.image_newline = Parameter(shape=(config.text_hidden_size, ),
                                       dtype=config.dtype)

    def forward(self, pixel_values, position_ids=None):
        """Encode ``pixel_values`` and project the result.

        Args:
            pixel_values: Input tensor passed to the vision tower.
            position_ids: Accepted for interface compatibility; not used here.

        Returns:
            The projected features, also marked as the network output
            ``'image_features'`` with dtype ``self.config.dtype``.

        Raises:
            ValueError: If no vision tower was created because
                ``config.vision_model_type`` is unsupported.
        """
        if self.vision_tower is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not callable" the call below would raise.
            raise ValueError(
                "LlavaNextVisionWrapper has no vision tower: unsupported "
                f"vision_model_type {self.config.vision_model_type!r}; "
                "only 'clip_vision_model' is supported.")

        image_features = self.vision_tower(pixel_values)
        # Keep everything along axis 1 except index 0
        # (presumably the CLS token - TODO confirm against the vision tower).
        select_size = concat([
            shape(image_features, 0), image_features.shape[1] - 1,
            shape(image_features, 2)
        ])
        selected_image_feature = slice(image_features,
                                       starts=[0, 1, 0],
                                       sizes=select_size)  # (bs, 576, c)
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features.mark_output('image_features', self.config.dtype)
        return image_features  # (bs, 576, c)

    @classmethod
    def from_hugging_face(cls,
                          hf_model_dir: str,
                          dtype: str = 'auto',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          **kwargs):
        '''
        Create a LlavaNextVisionWrapper object from given parameters

        Args:
            hf_model_dir: Path of the Hugging Face checkpoint directory.
            dtype: Forwarded to ``LlavaNextVisionConfig.from_hugging_face``.
            mapping: Forwarded to ``LlavaNextVisionConfig.from_hugging_face``.
            quant_config: Forwarded to
                ``LlavaNextVisionConfig.from_hugging_face``.
            **kwargs: Forwarded to ``LlavaNextVisionConfig.from_hugging_face``.

        Returns:
            A model instance on which ``ModelWeightsLoader`` has generated
            the weights.
        '''
        if os.environ.get("TRTLLM_DISABLE_UNIFIED_CONVERTER") is not None:
            logger.error(
                "Please enable unified converter to convert llava-next checkpoints."
            )

        config = LlavaNextVisionConfig.from_hugging_face(
            hf_model_dir,
            dtype=dtype,
            mapping=mapping,
            quant_config=quant_config,
            **kwargs)

        custom_dict = {}
        # NOTE(review): the name translation table is only applied when the
        # checkpoint path contains the lowercase substring "llava".
        if "llava" in hf_model_dir:
            custom_dict = {
                "vision_tower": "vision_tower.vision_model",
                "input_layernorm": "layer_norm1",
                "post_layernorm": "layer_norm2",
                "fc": "fc1",
                "proj": "fc2",
                "dense": "out_proj",
                # "pre_layrnorm" is reproduced verbatim from the original key
                # table - presumably it mirrors the checkpoint's own spelling.
                "pre_layernorm": "pre_layrnorm",
                "ln_f": "post_layernorm",
            }
        loader = ModelWeightsLoader(hf_model_dir, custom_dict)
        model = cls(config)
        loader.generate_tllm_weights(model)
        return model

    def save_checkpoint(self, output_dir, save_config=True):
        """Write the model weights (and optionally the config) to ``output_dir``.

        Files written:

        * ``rank{rank}.safetensors`` - every entry of ``named_parameters()``.
        * ``image_newlines.safetensors`` - only the ``image_newline`` tensor.
        * ``config.json`` - when ``save_config`` is true.

        Args:
            output_dir: Destination directory.
            save_config: Whether to also write ``config.json``.
        """
        # `import safetensors` alone does not load the `torch` submodule, so
        # `safetensors.torch` only resolves if some other module imported it
        # first. Import it explicitly to remove that dependency.
        from safetensors.torch import save_file

        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        image_newline = {
            "image_newline": numpy_to_torch(self.image_newline.raw_value)
        }
        save_file(weights, os.path.join(output_dir,
                                        f'rank{rank}.safetensors'))
        save_file(image_newline,
                  os.path.join(output_dir, 'image_newlines.safetensors'))
        if save_config:
            self.config.to_json_file(os.path.join(output_dir, 'config.json'))

    def prepare_inputs(self, max_batch_size, **kwargs):
        '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
            ranges of the dimensions of when using TRT dynamic shapes.

            @param max_batch_size: upper bound of the dynamic batch dimension.
            @return: a dict of the inputs which can be fed into the self.forward()
        '''
        # Three-value range for the dynamic batch dimension:
        # 1, the rounded-up half of the maximum, and the maximum.
        batch_size_range = [
            1, max(1, (max_batch_size + 1) // 2), max_batch_size
        ]
        pixel_values = Tensor(
            name='pixel_values',
            dtype=self.config.dtype,
            shape=[
                -1, self.config.num_channels, self.config.image_size,
                self.config.image_size
            ],
            dim_range=OrderedDict([
                ('batch_size', [batch_size_range]),
                # The non-batch dimensions are fixed: same value three times.
                ('in_channels', [[self.config.num_channels] * 3]),
                ('latent_height', [[self.config.image_size] * 3]),
                ('latent_width', [[self.config.image_size] * 3]),
            ]))
        return {'pixel_values': pixel_values}