Source code for tensorrt_llm.models.phi3.model

from typing import Optional, Union

import numpy as np
from transformers import AutoModelForCausalLM

from ..._utils import pad_vocab_size
from ...functional import PositionEmbeddingType, Tensor
from ...layers import (MLP, MOE, Attention, AttentionMaskType,
                       BlockSparseAttnParams, ColumnLinear, Embedding,
                       LayerNorm, MoeConfig, RmsNorm)
from ...lora_manager import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              PretrainedConfig, QuantConfig)
from .config import Phi3Config
from .convert import load_weights_from_hf_model


class Phi3DecoderLayer(Module):

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        tp_group = config.mapping.tp_group
        tp_size = config.mapping.tp_size

        attention_mask_type = AttentionMaskType.causal
        block_sparse_attn_params = BlockSparseAttnParams()
        q_scaling = 1.0
        self.gegelu_limit = None

        self.small_variant = config.architecture == "Phi3SmallForCausalLM"
        self.moe_variant = config.architecture == "PhiMoEForCausalLM"
        if self.small_variant:
            self.gegelu_limit = config.gegelu_limit

            # MuP uses norm_factor=attention_head_size (rather than sqrt(attention_head_size))
            # We achieve this using q_scaling = sqrt(attention_head_size)
            hidden_size = config.hidden_size
            num_attention_heads = config.num_attention_heads
            attention_head_size = hidden_size / num_attention_heads
            q_scaling = attention_head_size**.5

            block_sparse = (
                (layer_idx + 1) % config.dense_attention_every_n_layers) != 0
            attention_mask_type = AttentionMaskType.blocksparse if block_sparse else AttentionMaskType.causal

            block_sparse_attn_params = BlockSparseAttnParams(
                config.blocksparse_block_size,
                config.blocksparse_homo_head_pattern,
                config.blocksparse_num_local_blocks,
                config.blocksparse_vertical_stride)

        if self.small_variant or self.moe_variant:
            self.input_layernorm = LayerNorm(
                normalized_shape=config.hidden_size,
                dtype=config.dtype,
                eps=config.norm_epsilon)
            self.post_layernorm = LayerNorm(normalized_shape=config.hidden_size,
                                            dtype=config.dtype,
                                            eps=config.norm_epsilon)
        else:
            self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                           eps=config.norm_epsilon,
                                           dtype=config.dtype)
            self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                          eps=config.norm_epsilon,
                                          dtype=config.dtype)

        layers_range = config.mapping.pp_layers(config.num_hidden_layers)
        local_layer_idx = layer_idx - layers_range[0]
        position_embedding_type = PositionEmbeddingType.rope_gpt_neox

        rope_scaling_short_factors, rope_scaling_long_factors = None, None
        rope_scaling_short_mscale, rope_scaling_long_mscale = None, None
        original_max_position_embeddings = config.max_position_embeddings

        if hasattr(config, "longrope_scaling_short_factors"):
            rope_scaling_short_factors = np.asarray(
                config.longrope_scaling_short_factors).astype(np.float32)
            rope_scaling_long_factors = np.asarray(
                config.longrope_scaling_long_factors).astype(np.float32)

            original_max_position_embeddings = config.original_max_position_embeddings
            position_embedding_type = PositionEmbeddingType.long_rope

            if self.small_variant or self.moe_variant:
                rope_scaling_short_mscale = config.longrope_short_mscale
                rope_scaling_long_mscale = config.longrope_long_mscale

        self.attention = Attention(
            local_layer_idx=local_layer_idx,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            position_embedding_type=position_embedding_type,
            rotary_embedding_base=config.rotary_base,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=attention_mask_type,
            bias=self.small_variant or self.moe_variant,
            q_scaling=q_scaling,
            tp_group=tp_group,
            tp_size=tp_size,
            quant_mode=config.quant_mode,
            rope_scaling_short_factors=rope_scaling_short_factors,
            rope_scaling_long_factors=rope_scaling_long_factors,
            rope_scaling_short_mscale=rope_scaling_short_mscale,
            rope_scaling_long_mscale=rope_scaling_long_mscale,
            original_max_position_embeddings=original_max_position_embeddings,
            block_sparse_params=block_sparse_attn_params)

        ClsMLP = MLP
        mlp_kwargs = {}
        if hasattr(config, "moe"):
            ClsMLP = MOE
            moe_config = MoeConfig()
            for key, value in config.moe.items():
                setattr(moe_config, key, value)
            mlp_kwargs = {
                "moe_config": moe_config,
                "mapping": config.mapping,
            }

        self.mlp = ClsMLP(hidden_size=config.hidden_size,
                          ffn_hidden_size=config.intermediate_size,
                          hidden_act=config.hidden_act,
                          dtype=config.dtype,
                          tp_group=tp_group,
                          tp_size=tp_size,
                          quant_mode=config.quant_mode,
                          bias=self.small_variant,
                          **mlp_kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask=None,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        lora_layer_params=None,
    ):

        input_layernorm_output = self.input_layernorm(hidden_states)
        attention_output = self.attention(
            input_layernorm_output,
            attention_mask=attention_mask,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            norm_before_bmm1=not self.small_variant,
            lora_layer_params=lora_layer_params,
        )

        if use_cache:
            attention_output, presents = attention_output

        post_attention_input = hidden_states + attention_output
        post_attention_output = self.post_layernorm(post_attention_input)
        if self.small_variant:
            feed_forward_hidden_states = self.mlp(
                post_attention_output,
                gegelu_limit=self.gegelu_limit,
                lora_layer_params=lora_layer_params)
        else:
            feed_forward_hidden_states = self.mlp(
                post_attention_output, lora_layer_params=lora_layer_params)
        hidden_states = post_attention_input + feed_forward_hidden_states
        if use_cache:
            return (hidden_states, presents)
        return hidden_states


[docs] class Phi3Model(Module): def __init__(self, config: PretrainedConfig): super().__init__() self.vocab_embedding = Embedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, dtype=config.dtype) self.layers = DecoderLayerList(Phi3DecoderLayer, config) self.small_variant = config.architecture == "Phi3SmallForCausalLM" self.moe_variant = config.architecture == "PhiMoEForCausalLM" if self.small_variant or self.moe_variant: self.ln_f = LayerNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype) if self.small_variant: self.mup_embedding_multiplier = config.mup_embedding_multiplier else: self.ln_f = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype)
[docs] def forward( self, input_ids: Tensor, position_ids=None, use_cache=False, attention_mask=None, kv_cache_params=None, attention_params=None, prompt_embedding_table=None, prompt_tasks=None, prompt_vocab_size=None, lora_params=None, ): args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size ] if prompt_embedding_table is not None else [] hidden_states = self.vocab_embedding(input_ids, *args) if self.small_variant and self.mup_embedding_multiplier > 0.0: hidden_states = hidden_states * self.mup_embedding_multiplier hidden_states = self.layers( hidden_states, use_cache=use_cache, attention_mask=attention_mask, kv_cache_params=kv_cache_params, attention_params=attention_params, lora_params=lora_params, ) if use_cache: hidden_states, presents = hidden_states hidden_states = self.ln_f(hidden_states) if use_cache: return (hidden_states, tuple(presents)) return hidden_states
[docs] class Phi3ForCausalLM(DecoderModelForCausalLM): config_class = Phi3Config def __init__(self, config: PretrainedConfig): transformer = Phi3Model(config) vocab_size_padded = pad_vocab_size(config.vocab_size, config.mapping.tp_size) self.moe_variant = config.architecture == "PhiMoEForCausalLM" lm_head = ColumnLinear(config.hidden_size, vocab_size_padded, bias=self.moe_variant, dtype=config.dtype, tp_group=config.mapping.tp_group, tp_size=config.mapping.tp_size, gather_output=True) self.trtllm_modules_to_hf_modules = { "attn_qkv": ["qkv_proj", "query_key_value"], "attn_dense": ["o_proj", "dense"], "mlp_h_to_4h": ["gate_up_proj", "up_proj"], "mlp_4h_to_h": "down_proj", } super().__init__(config, transformer, lm_head)
[docs] @classmethod def from_hugging_face( cls, hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'], dtype: str = 'auto', mapping: Optional[Mapping] = None, quant_config: Optional[QuantConfig] = None, **kwargs): import transformers assert hf_model_or_dir is not None use_preloading = isinstance(hf_model_or_dir, transformers.PreTrainedModel) if use_preloading: hf_model = hf_model_or_dir hf_config_or_dir = hf_model.config else: hf_model_dir = hf_model_or_dir hf_config_or_dir = hf_model_or_dir config = Phi3Config.from_hugging_face(hf_config_or_dir, dtype=dtype, mapping=mapping, quant_config=quant_config, **kwargs) if not use_preloading: hf_model = AutoModelForCausalLM.from_pretrained( hf_model_dir, torch_dtype="auto", trust_remote_code=True) assert isinstance(hf_model, transformers.PreTrainedModel) weights = load_weights_from_hf_model(hf_model, config) model = cls(config) model.load(weights) return model
[docs] def use_lora(self, lora_config: LoraConfig): use_lora(self, lora_config, self.trtllm_modules_to_hf_modules)