Source code for tensorrt_llm.models.gemma.model

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, Optional

from tensorrt_llm.models.gemma.convert import (QuantizeModifiers, Weights,
                                               load_gemma_weights_from_hf_model,
                                               non_modelopt_quantize_if_needed)
from tensorrt_llm.quantization.mode import (MODELOPT_FLOW_QUANTIZATIONS,
                                            QuantAlgo)

from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType,
                           Tensor, cast, recv, send)
from ...layers import (Attention, AttentionMaskType, AttentionParams,
                       ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                       LoraParams, PositionEmbeddingType, RmsNorm)
from ...lora_manager import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              QuantConfig, save_checkpoint, save_config)
from .config import GemmaConfig

if TYPE_CHECKING:
    from .config import HfConfigOrDir


class GemmaDecoderLayer(Module):
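    """A single Gemma decoder block: input RMSNorm, self-attention, and a gated
    MLP, with additional pre-/post-feedforward RMSNorms when
    `config.inter_layernorms` is set."""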

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                       eps=config.norm_epsilon,
                                       dtype=config.dtype)

        layers_range = config.mapping.pp_layers(config.num_hidden_layers)
        self.local_layer_idx = layer_idx - layers_range[0]

        q_scaling = 1.0
        max_attn_value = 0.0
        qk_layernorm = False
        is_sliding = False
        rotary_base = config.rotary_base
        rotary_base_local = None

        gemma2_config = config.gemma2_config()
        gemma3_config = config.gemma3_config()
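        # Gemma 2 adjusts the attention query scaling via query_pre_attn_scalar and
        # soft-caps the attention logits; Gemma 3 additionally enables QK RMSNorm
        # and marks sliding-window (local) layers, which use their own rotary base.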
        if gemma2_config:
            q_scaling = math.sqrt(
                gemma2_config.query_pre_attn_scalar) / math.sqrt(
                    config.head_size)
            max_attn_value = config.attn_logit_softcapping or 0.0
        elif gemma3_config:
            qk_layernorm = True
            q_scaling = math.sqrt(
                gemma3_config.query_pre_attn_scalar) / math.sqrt(
                    config.head_size)
            is_sliding = bool(
                (layer_idx + 1) % gemma3_config.sliding_window_pattern)
            rotary_base_local = config.rope_local_base_freq

        self.attention = Attention(
            local_layer_idx=self.local_layer_idx,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            attention_head_size=config.head_size,
            qk_layernorm=qk_layernorm,
            layernorm_type=LayerNormType.RmsNorm,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=rotary_base,
            rotary_embedding_base_local=rotary_base_local,
            rotary_embedding_scaling=config.rotary_scaling,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            quant_mode=config.quant_mode,
            q_scaling=q_scaling,
            max_attn_value=max_attn_value,
            is_local=is_sliding,
        )

        mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size

        self.mlp = GatedMLP(hidden_size=config.hidden_size,
                            ffn_hidden_size=mlp_hidden_size,
                            hidden_act=config.hidden_act,
                            dtype=config.dtype,
                            bias=config.mlp_bias,
                            tp_group=config.mapping.tp_group,
                            tp_size=config.mapping.tp_size,
                            quant_mode=config.quant_mode)

        if self.config.inter_layernorms:
            self.pre_feedforward_layernorm = RmsNorm(
                normalized_shape=config.hidden_size,
                eps=config.norm_epsilon,
                dtype=config.dtype)
            self.post_feedforward_layernorm = RmsNorm(
                normalized_shape=config.hidden_size,
                eps=config.norm_epsilon,
                dtype=config.dtype)

        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                      eps=config.norm_epsilon,
                                      dtype=config.dtype)

    def forward(self,
                hidden_states: Tensor,
                attention_mask: Optional[Tensor] = None,
                use_cache: bool = False,
                kv_cache_params: Optional[KeyValueCacheParams] = None,
                attention_params: Optional[AttentionParams] = None,
                lora_layer_params: Optional[LoraParams] = None,
                next_layer_input_layernorm_args=None):
        # assert not (
        #     default_net().plugin_config.reduce_fusion and self.has_residual_mlp
        # ), "Custom all reduce and residual mlp can't be enabled at the same time."
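        # With the reduce_fusion plugin, the previous layer's fused all-reduce has
        # already combined the residual and applied this layer's input layernorm,
        # so hidden_states arrives here as a (hidden_states, residual) pair.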
        if default_net(
        ).plugin_config.reduce_fusion and self.local_layer_idx > 0:
            hidden_states, residual = hidden_states  # FIXME: verify that the correct residual value is unpacked from hidden_states here.
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            norm_before_bmm1=True,
            lora_layer_params=lora_layer_params,
            all_reduce_params=AllReduceParams(
                fusion_op=AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM
                if default_net().plugin_config.reduce_fusion else
                AllReduceFusionOp.NONE,
                residual=residual,
                norm_weight=self.post_layernorm.weight.value,
                norm_pre_residual_weight=self.pre_feedforward_layernorm.weight.
                value if self.config.inter_layernorms else None,
                eps=self.post_layernorm.eps))

        if use_cache:
            attention_output, presents = attention_output

        if default_net().plugin_config.reduce_fusion:
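            # The fused all-reduce above already performed the residual add and the
            # surrounding RMSNorms, so it returns both the hidden states and the
            # new residual.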
            hidden_states, residual = attention_output
        else:
            if self.config.inter_layernorms:
                attention_output = self.post_layernorm(attention_output)
            hidden_states = residual + attention_output
            residual = hidden_states
            if self.config.inter_layernorms:
                hidden_states = self.pre_feedforward_layernorm(hidden_states)
            else:
                hidden_states = self.post_layernorm(hidden_states)

        if next_layer_input_layernorm_args is not None:
            hidden_states = self.mlp(
                hidden_states,
                lora_layer_params=lora_layer_params,
                all_reduce_params=AllReduceParams(
                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM
                    if default_net().plugin_config.reduce_fusion else
                    AllReduceFusionOp.NONE,
                    residual=residual,
                    norm_weight=next_layer_input_layernorm_args[0],
                    norm_pre_residual_weight=self.post_feedforward_layernorm.
                    weight.value,
                    eps=next_layer_input_layernorm_args[1]))
        else:
            hidden_states = self.mlp(hidden_states,
                                     lora_layer_params=lora_layer_params)

            if self.config.inter_layernorms:
                hidden_states = self.post_feedforward_layernorm(hidden_states)
            hidden_states = residual + hidden_states
        if use_cache:
            return (hidden_states, presents)
        return hidden_states


class GemmaModel(Module):
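    """Gemma transformer backbone: scaled token embedding, a stack of
    GemmaDecoderLayer blocks, and a final RMSNorm, with send/recv handoffs
    between pipeline-parallel ranks."""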

    def __init__(self, config: GemmaConfig) -> None:
        super().__init__()

        self.mapping = config.mapping
        if self.mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(config.vocab_size,
                                             config.hidden_size,
                                             dtype=config.dtype)

        self.layers = DecoderLayerList(GemmaDecoderLayer, config)

        if self.mapping.is_last_pp_rank():
            self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                                eps=config.norm_epsilon,
                                dtype=config.dtype)
        self.hidden_size = config.hidden_size

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None):

        ptuning_args = [
            prompt_embedding_table, prompt_tasks, prompt_vocab_size
        ] if prompt_embedding_table is not None else []

        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
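            # Gemma scales the token embeddings by sqrt(hidden_size) before the first layer.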
            hidden_states = cast(hidden_states * math.sqrt(self.hidden_size),
                                 hidden_states.dtype)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
        hidden_states = self.layers.forward(
            hidden_states,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_params=lora_params,
        )

        if use_cache:
            hidden_states, presents = hidden_states

        if self.mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states


class GemmaForCausalLM(DecoderModelForCausalLM):
    config_class = GemmaConfig

    def __init__(self, config: GemmaConfig):
        transformer = GemmaModel(config)

        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            lm_head = None

        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        super().__init__(config, transformer, lm_head)

    @staticmethod
    def _load_gemma_weights_from_hf(hf_model_dir: "HfConfigOrDir",
                                    trt_llm_config: GemmaConfig,
                                    *,
                                    load_model_on_cpu: bool = True) -> Weights:
        """`AutoModelForCausalLM.from_pretrained` will resolve the correct Gemma
        variant, whether Gemma, Gemma 2, or a future version."""
        import transformers
        hf_gemma = transformers.AutoModelForCausalLM.from_pretrained(
            hf_model_dir,
            device_map="cpu" if load_model_on_cpu else "auto",
            torch_dtype='auto',
        )
        weights = load_gemma_weights_from_hf_model(hf_gemma, trt_llm_config)
        del hf_gemma
        return weights

    @classmethod
    def from_hugging_face(cls,
                          hf_model_dir: "HfConfigOrDir",
                          dtype='float16',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          load_model_on_cpu: bool = True,
                          **kwargs):
        config = GemmaConfig.from_hugging_face(hf_config_or_dir=hf_model_dir,
                                               dtype=dtype,
                                               mapping=mapping,
                                               quant_config=quant_config,
                                               **kwargs)
        model = GemmaForCausalLM(config)
        weights = cls._load_gemma_weights_from_hf(
            hf_model_dir, config, load_model_on_cpu=load_model_on_cpu)
        model.load(weights)
        return model
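
    # Usage sketch (the checkpoint path below is a hypothetical placeholder;
    # dtype, mapping, and quant_config follow the defaults above):
    #
    #   model = GemmaForCausalLM.from_hugging_face("<path/to/hf/gemma>",
    #                                              dtype="bfloat16")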

    NATIVE_QUANT_FLOW = {
        QuantAlgo.W8A16, QuantAlgo.W4A16,
        QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
        QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN,
        QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN,
        QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
    }

    @classmethod
    def assert_valid_quant_algo(cls, quant_algo: Optional[QuantAlgo]):
        allowed_quant_values = {
            None
        } | cls.NATIVE_QUANT_FLOW | MODELOPT_FLOW_QUANTIZATIONS
        assert quant_algo in allowed_quant_values, f"{quant_algo} isn't in the allowed `QuantAlgo` values for this model: {allowed_quant_values}"

    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'float16',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        gemma_config_kwargs: Dict[str, Any] = None,
        **quantize_kwargs: Dict[str, Any],
    ):
        config = GemmaConfig.from_hugging_face(hf_model_dir,
                                               dtype=dtype,
                                               mapping=mapping,
                                               quant_config=quant_config,
                                               **(gemma_config_kwargs or {}))
        quant_algo = config.quantization.quant_algo
        if quant_algo is None and config.quantization.kv_cache_quant_algo is None:
            raise ValueError(
                "There is no point in calling `quantize()` if both `quant_algo` and `kv_cache_quant_algo` are `None`"
            )
        elif quant_algo in MODELOPT_FLOW_QUANTIZATIONS:
            super().quantize(hf_model_dir,
                             output_dir,
                             dtype=config.dtype,
                             mapping=config.mapping,
                             quant_config=config.quantization,
                             **quantize_kwargs)
        elif quant_algo in cls.NATIVE_QUANT_FLOW:
            save_config(config, output_dir=output_dir, log=True)
            for config in config.for_each_rank():
                hf_weights = cls._load_gemma_weights_from_hf(
                    hf_model_dir, config)
                ranked_weights = non_modelopt_quantize_if_needed(
                    hf_weights,
                    model_dir=hf_model_dir,
                    quantize_modifiers=QuantizeModifiers(),
                    trt_llm_config=config)
                save_checkpoint(
                    output_dir=output_dir,
                    weights=ranked_weights,
                    rank=config.mapping.rank,
                )
                del hf_weights
        else:
            cls.assert_valid_quant_algo(quant_algo)
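
    # Usage sketch for the native (non-ModelOpt) quantization flow; the paths
    # are hypothetical placeholders and W8A16 is just one of the allowed
    # NATIVE_QUANT_FLOW algorithms:
    #
    #   GemmaForCausalLM.quantize(
    #       "<hf_gemma_dir>", "<trtllm_ckpt_dir>", dtype="bfloat16",
    #       quant_config=QuantConfig(quant_algo=QuantAlgo.W8A16))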

    def use_lora(self, lora_config: LoraConfig) -> None:
        # Use the default trtllm->hf module mapping.
        return use_lora(self, lora_config)