# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict
from typing import Optional
import safetensors.torch
from ..._utils import numpy_to_torch
from ...functional import ACT2FN, Tensor, concat, shape, slice
from ...layers import Linear
from ...logger import logger
from ...mapping import Mapping
from ...models import CLIPVisionTransformer
from ...module import Module
from ...parameter import Parameter
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import PretrainedModel, QuantConfig
from .config import LlavaNextVisionConfig
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/llava_next/modeling_llava_next.py#L149
class LlavaNextMultiModalProjector(Module):

    def __init__(self, config: LlavaNextVisionConfig):
        super().__init__()
        self.linear_1 = Linear(config.hidden_size,
                               config.text_hidden_size,
                               dtype=config.dtype)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = Linear(config.text_hidden_size,
                               config.text_hidden_size,
                               dtype=config.dtype)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
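# A minimal usage sketch of the projector (shapes are illustrative and assume a
# CLIP ViT-L/14 tower with hidden_size=1024 feeding a 4096-dim language model):
#   projector = LlavaNextMultiModalProjector(config)
#   # image_features: (batch, num_patches, 1024) patch embeddings
#   projected = projector(image_features)
#   # projected:      (batch, num_patches, 4096) tokens for the language model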
class LlavaNextVisionWrapper(PretrainedModel):

    def __init__(self, config: LlavaNextVisionConfig):
        super().__init__(config)
        self.vision_tower = None
        self.config = config
        if config.vision_model_type == "clip_vision_model":
            self.vision_tower = CLIPVisionTransformer(
                image_size=config.image_size,
                num_channels=config.num_channels,
                patch_size=config.patch_size,
                hidden_size=config.hidden_size,
                num_attention_heads=config.num_attention_heads,
                max_position_embeddings=config.max_position_embeddings,
                norm_epsilon=config.norm_epsilon,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                num_hidden_layers=config.num_hidden_layers,
                require_ln_f=False,
                mapping=config.mapping,
                dtype=config.dtype)
        else:
            logger.error(
                "Currently TRT-LLM only supports the CLIP vision transformer.")
        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
        self.image_newline = Parameter(shape=(config.text_hidden_size, ),
                                       dtype=config.dtype)
    def forward(self, pixel_values, position_ids=None):
        # Encode the pixels with the vision tower, drop the leading CLS token
        # and project the remaining patch features into the language model's
        # hidden size.
        image_features = self.vision_tower(pixel_values)
        select_size = concat([
            shape(image_features, 0), image_features.shape[1] - 1,
            shape(image_features, 2)
        ])
        selected_image_feature = slice(image_features,
                                       starts=[0, 1, 0],
                                       sizes=select_size)  # (bs, 576, c)
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features.mark_output('image_features', self.config.dtype)
        return image_features  # (bs, 576, c)
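    # Shape sketch for forward(), assuming the usual LLaVA-NeXT CLIP ViT-L/14
    # tower at 336x336 resolution (24 x 24 = 576 patches plus one CLS token;
    # other checkpoints change these numbers):
    #   pixel_values:        (bs, 3, 336, 336)
    #   vision_tower output: (bs, 577, hidden_size)
    #   after slice:         (bs, 576, hidden_size)
    #   after projector:     (bs, 576, text_hidden_size)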
    @classmethod
    def from_hugging_face(cls,
                          hf_model_dir: str,
                          dtype: str = 'auto',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          **kwargs):
        ''' Create a LlavaNextVisionWrapper object from the given parameters
        '''
        if os.environ.get("TRTLLM_DISABLE_UNIFIED_CONVERTER") is not None:
            logger.error(
                "Please enable the unified converter to convert llava-next checkpoints."
            )
        config = LlavaNextVisionConfig.from_hugging_face(
            hf_model_dir,
            dtype=dtype,
            mapping=mapping,
            quant_config=quant_config,
            **kwargs)

        custom_dict = {}
        if "llava" in hf_model_dir:
            # Map TRT-LLM parameter names to the names used in the HF
            # checkpoint. "pre_layrnorm" mirrors the (misspelled) attribute
            # name of the HF CLIP vision model and must be kept as-is.
            custom_dict = {
                "vision_tower": "vision_tower.vision_model",
                "input_layernorm": "layer_norm1",
                "post_layernorm": "layer_norm2",
                "fc": "fc1",
                "proj": "fc2",
                "dense": "out_proj",
                "pre_layernorm": "pre_layrnorm",
                "ln_f": "post_layernorm",
            }
        loader = ModelWeightsLoader(hf_model_dir, custom_dict)
        model = cls(config)
        loader.generate_tllm_weights(model)
        return model
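    # Example (a sketch; the model id, dtype and output directory are
    # illustrative, not part of this module):
    #   model = LlavaNextVisionWrapper.from_hugging_face(
    #       "llava-hf/llava-v1.6-mistral-7b-hf", dtype="float16")
    #   model.save_checkpoint("./llava_next_vision_ckpt")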
    def save_checkpoint(self, output_dir, save_config=True):
        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        # The image_newline embedding is additionally written to its own
        # safetensors file.
        image_newline = {
            "image_newline": numpy_to_torch(self.image_newline.raw_value)
        }
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        safetensors.torch.save_file(
            image_newline,
            os.path.join(output_dir, 'image_newlines.safetensors'))
        if save_config:
            self.config.to_json_file(os.path.join(output_dir, 'config.json'))
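# Example of persisting a converted checkpoint (the output directory is
# illustrative). For mapping rank r this writes rank{r}.safetensors,
# image_newlines.safetensors and, when save_config is True, config.json:
#   model.save_checkpoint("./llava_next_vision_ckpt", save_config=True)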