Source code for tensorrt_llm.models.commandr.model

from typing import Optional

from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import recv, send
from ...layers import (Attention, AttentionMaskType, ColumnLinear, Embedding,
                       GatedMLP, LayerNorm, PositionEmbeddingType)
from ...mapping import Mapping
from ...module import Module
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              QuantConfig)
from .config import CohereConfig


class CohereDecoderLayer(Module):
    '''One Cohere decoder block.

    A single input LayerNorm feeds both the attention branch and the gated
    MLP branch; the two branch outputs are added to the un-normalised input
    (the residual) to produce the layer output.
    '''

    def __init__(self, config: CohereConfig, layer_idx: int):
        '''Create the layer norm, attention and MLP sub-modules.

        Args:
            config: Model configuration; supplies sizes, dtype, epsilon,
                parallel mapping and quantisation mode.
            layer_idx: Index of this layer in the whole model. The index
                local to this pipeline stage is derived from it.
        '''
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        mapping = config.mapping
        hidden_size = config.hidden_size
        dtype = config.dtype
        eps = config.norm_epsilon
        quant_mode = config.quant_mode

        self.input_layernorm = LayerNorm(normalized_shape=hidden_size,
                                         eps=eps,
                                         bias=False,
                                         dtype=dtype)

        # Offset the global index by the first layer index returned by
        # `pp_layers` for this mapping.
        first_local_layer = mapping.pp_layers(config.num_hidden_layers)[0]
        self.local_layer_idx = layer_idx - first_local_layer

        self.attention = Attention(
            local_layer_idx=self.local_layer_idx,
            hidden_size=hidden_size,
            attention_head_size=config.head_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            dtype=dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gptj,
            rotary_embedding_base=config.rotary_base,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            qk_layernorm=config.qk_layernorm,
            layernorm_share=False,
            eps=eps,
            quant_mode=quant_mode)

        self.mlp = GatedMLP(hidden_size=hidden_size,
                            ffn_hidden_size=config.intermediate_size,
                            hidden_act=config.hidden_act,
                            dtype=dtype,
                            bias=False,
                            tp_group=mapping.tp_group,
                            tp_size=mapping.tp_size,
                            quant_mode=quant_mode)

    def forward(self,
                hidden_states,
                attention_mask=None,
                use_cache=False,
                spec_decoding_params=None,
                kv_cache_params=None,
                attention_params=None):
        '''Apply the decoder block to ``hidden_states``.

        Args:
            hidden_states: Input to the layer; also used as the residual.
            attention_mask: Forwarded to the attention module.
            use_cache: When true, the attention module's second return
                value is returned alongside the layer output.
            spec_decoding_params: Forwarded to the attention module.
            kv_cache_params: Forwarded to the attention module.
            attention_params: Forwarded to the attention module.

        Returns:
            The layer output, or ``(output, presents)`` when ``use_cache``
            is true.

        Raises:
            AssertionError: If the ``reduce_fusion`` plugin option is on.
        '''
        reduce_fusion_on = default_net().plugin_config.reduce_fusion
        assert not reduce_fusion_on, \
            "Custom all reduce and residual mlp can't be enabled at the same time."

        # Both branches consume the same normalised tensor.
        normed_states = self.input_layernorm(hidden_states)

        attention_result = self.attention(
            normed_states,
            attention_mask=attention_mask,
            use_cache=use_cache,
            spec_decoding_params=spec_decoding_params,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params)

        presents = None
        if use_cache:
            attention_result, presents = attention_result

        mlp_result = self.mlp(normed_states)
        layer_output = hidden_states + attention_result + mlp_result

        return (layer_output, presents) if use_cache else layer_output


class CohereModel(Module):
    '''Cohere transformer body: embedding, decoder layers and final norm.

    The embedding is only created on the first pipeline-parallel rank and
    the final LayerNorm only on the last one.
    '''

    def __init__(self, config: CohereConfig) -> None:
        '''Build the sub-modules needed by this pipeline rank.

        Args:
            config: Model configuration; supplies sizes, dtype, epsilon and
                the parallel mapping.
        '''
        super().__init__()

        mapping = config.mapping
        self.mapping = mapping

        if mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(config.vocab_size,
                                             config.hidden_size,
                                             dtype=config.dtype,
                                             tp_group=mapping.tp_group,
                                             tp_size=mapping.tp_size,
                                             tp_rank=mapping.tp_rank)

        self.layers = DecoderLayerList(CohereDecoderLayer, config)

        if mapping.is_last_pp_rank():
            self.ln_f = LayerNorm(normalized_shape=config.hidden_size,
                                  eps=config.norm_epsilon,
                                  bias=False,
                                  dtype=config.dtype)

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        use_cache=False,
        attention_mask=None,
        spec_decoding_params=None,
        kv_cache_params=None,
        attention_params=None,
        hidden_states=None,
    ):
        '''Run the decoder stack for this pipeline rank.

        Args:
            input_ids: Passed to the embedding on the first pipeline rank.
            position_ids: Accepted but not used by this method.
            use_cache: When true, the second value returned by the layer
                list is returned alongside the hidden states.
            attention_mask: Forwarded to the decoder layers.
            spec_decoding_params: Forwarded to the decoder layers.
            kv_cache_params: Forwarded to the decoder layers.
            attention_params: Forwarded to the decoder layers.
            hidden_states: Passed to ``recv`` on ranks other than the first.

        Returns:
            The hidden states, or ``(hidden_states, presents)`` when
            ``use_cache`` is true.
        '''
        # First rank embeds the tokens; other ranks receive activations
        # from the previous pipeline rank.
        stage_input = (self.vocab_embedding(input_ids)
                       if self.mapping.is_first_pp_rank() else recv(
                           hidden_states, self.mapping.prev_pp_rank()))

        stack_output = self.layers.forward(
            stage_input,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            spec_decoding_params=spec_decoding_params)

        presents = None
        if use_cache:
            stack_output, presents = stack_output

        # Last rank applies the final norm; other ranks hand the
        # activations to the next pipeline rank.
        stage_output = (self.ln_f(stack_output)
                        if self.mapping.is_last_pp_rank() else send(
                            stack_output, self.mapping.next_pp_rank()))

        if not use_cache:
            return stage_output
        return (stage_output, presents)


class CohereForCausalLM(DecoderModelForCausalLM):
    '''Cohere decoder model with a language-model head.

    Combines a ``CohereModel`` with a ``ColumnLinear`` LM head; the head is
    only created on the last pipeline-parallel rank.
    '''
    config_class = CohereConfig

    def __init__(self, config: CohereConfig):
        '''Build the transformer body and, on the last rank, the LM head.

        Args:
            config: Model configuration; supplies sizes, dtype, quant mode
                and the parallel mapping.
        '''
        transformer = CohereModel(config)
        # The head's output dimension is the vocabulary size padded by
        # `pad_vocab_size` for the tensor-parallel size.
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            lm_head = None
        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        super().__init__(config, transformer, lm_head)

    @classmethod
    def from_hugging_face(cls,
                          hf_model_or_dir: str,
                          dtype: str = 'auto',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          **kwargs):
        '''Create a CohereForCausalLM object from the given parameters.

        Args:
            hf_model_or_dir: Passed to ``CohereConfig.from_hugging_face``
                and to ``ModelWeightsLoader`` as the weight source.
            dtype: Forwarded to ``CohereConfig.from_hugging_face``.
            mapping: Optional parallel mapping, forwarded to the config.
            quant_config: Optional quantisation config, forwarded to the
                config.
            **kwargs: Extra options forwarded to
                ``CohereConfig.from_hugging_face``.

        Returns:
            The constructed model after the loader has generated its
            weights.
        '''
        config = CohereConfig.from_hugging_face(hf_model_or_dir,
                                                dtype=dtype,
                                                mapping=mapping,
                                                quant_config=quant_config,
                                                **kwargs)
        model = cls(config)
        # Name overrides handed to the weights loader.
        # NOTE(review): presumably module name -> checkpoint name; confirm
        # against ModelWeightsLoader.
        custom_dict = {
            'q_layernorm': 'q_norm',
            'k_layernorm': 'k_norm',
        }
        loader = ModelWeightsLoader(hf_model_or_dir, custom_dict)
        loader.check_share_embedding(config)
        loader.generate_tllm_weights(model)
        return model