from typing import Optional
from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import recv, send
from ...layers import (Attention, AttentionMaskType, ColumnLinear, Embedding,
                       GatedMLP, LayerNorm, PositionEmbeddingType)
from ...mapping import Mapping
from ...module import Module
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              QuantConfig)
from .config import CohereConfig
class CohereDecoderLayer(Module):
    """Single Cohere decoder block.

    One bias-free input LayerNorm feeds both the attention sub-layer and the
    gated MLP; the two outputs are then added to the residual in one step.
    """

    def __init__(self, config: CohereConfig, layer_idx: int):
        """Build the layer-norm, attention and MLP sub-modules for ``layer_idx``.

        Args:
            config: Model configuration supplying sizes, dtype, mapping and
                quantization mode.
            layer_idx: Index of this layer; the first index of
                ``config.mapping.pp_layers(...)`` is subtracted from it to
                obtain ``local_layer_idx``.
        """
        super().__init__()
        mapping = config.mapping

        self.layer_idx = layer_idx
        self.config = config

        self.input_layernorm = LayerNorm(
            normalized_shape=config.hidden_size,
            eps=config.norm_epsilon,
            bias=False,
            dtype=config.dtype,
        )

        # Offset of this layer relative to the first entry of the layer range
        # that the mapping reports for this model.
        first_layer_idx = mapping.pp_layers(config.num_hidden_layers)[0]
        self.local_layer_idx = layer_idx - first_layer_idx

        self.attention = Attention(
            local_layer_idx=self.local_layer_idx,
            hidden_size=config.hidden_size,
            attention_head_size=config.head_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gptj,
            rotary_embedding_base=config.rotary_base,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            qk_layernorm=config.qk_layernorm,
            layernorm_share=False,
            eps=config.norm_epsilon,
            quant_mode=config.quant_mode,
        )

        self.mlp = GatedMLP(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            dtype=config.dtype,
            bias=False,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            quant_mode=config.quant_mode,
        )

    def forward(self,
                hidden_states,
                attention_mask=None,
                use_cache=False,
                spec_decoding_params=None,
                kv_cache_params=None,
                attention_params=None):
        """Run the block on ``hidden_states``.

        Returns:
            The updated hidden states, or a ``(hidden_states, presents)``
            tuple when ``use_cache`` is truthy.

        Raises:
            AssertionError: If ``reduce_fusion`` is enabled in the current
                network's plugin config.
        """
        fusion_enabled = default_net().plugin_config.reduce_fusion
        assert not fusion_enabled, "Custom all reduce and residual mlp can't be enabled at the same time."

        normed = self.input_layernorm(hidden_states)

        attn_result = self.attention(
            normed,
            attention_mask=attention_mask,
            use_cache=use_cache,
            spec_decoding_params=spec_decoding_params,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
        )
        # With caching on, attention hands back (output, presents).
        if use_cache:
            attn_out, presents = attn_result
        else:
            attn_out = attn_result

        # The MLP consumes the same normalized input as attention, not the
        # attention output.
        mlp_out = self.mlp(normed)
        output = hidden_states + attn_out + mlp_out

        return (output, presents) if use_cache else output
class CohereModel(Module):
    """Cohere transformer body: optional embedding, decoder layers, optional final norm.

    The embedding is only created on the first pipeline-parallel rank and the
    final LayerNorm only on the last one.
    """

    def __init__(self, config: CohereConfig) -> None:
        """Create the sub-modules that belong on this pipeline rank."""
        super().__init__()
        mapping = config.mapping
        self.mapping = mapping

        if mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(
                config.vocab_size,
                config.hidden_size,
                dtype=config.dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                tp_rank=mapping.tp_rank,
            )

        self.layers = DecoderLayerList(CohereDecoderLayer, config)

        if mapping.is_last_pp_rank():
            self.ln_f = LayerNorm(
                normalized_shape=config.hidden_size,
                eps=config.norm_epsilon,
                bias=False,
                dtype=config.dtype,
            )

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        use_cache=False,
        attention_mask=None,
        spec_decoding_params=None,
        kv_cache_params=None,
        attention_params=None,
        hidden_states=None,
    ):
        """Embed (or receive) activations, run the decoder layers, then norm (or send).

        ``position_ids`` is accepted but not used by this method.

        Returns:
            The resulting hidden states, or ``(hidden_states, presents)`` when
            ``use_cache`` is truthy.
        """
        # First rank starts from token ids; later ranks take activations from
        # the previous pipeline rank.
        if not self.mapping.is_first_pp_rank():
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
        else:
            hidden_states = self.vocab_embedding(input_ids)

        layers_out = self.layers.forward(
            hidden_states,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            spec_decoding_params=spec_decoding_params,
        )
        if use_cache:
            hidden_states, presents = layers_out
        else:
            hidden_states = layers_out

        # Last rank applies the final norm; every other rank forwards its
        # activations to the next pipeline rank.
        if not self.mapping.is_last_pp_rank():
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())
        else:
            hidden_states = self.ln_f(hidden_states)

        return (hidden_states, presents) if use_cache else hidden_states
class CohereForCausalLM(DecoderModelForCausalLM):
    """Cohere causal language model: a ``CohereModel`` body plus an LM head.

    The LM head is only created on the last pipeline-parallel rank; other
    ranks pass ``None`` to the base class.
    """

    config_class = CohereConfig

    def __init__(self, config: CohereConfig):
        """Assemble the transformer and (on the last pp rank) the LM head.

        Args:
            config: Model configuration supplying sizes, dtype, mapping and
                quantization mode.
        """
        transformer = CohereModel(config)
        # Output dimension of the head is the vocab size after padding with
        # respect to tp_size (presumably rounded up to a multiple of it --
        # confirm against pad_vocab_size).
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            lm_head = None
        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        super().__init__(config, transformer, lm_head)

    @classmethod
    def from_hugging_face(cls,
                          hf_model_or_dir: str,
                          dtype: str = 'auto',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          **kwargs):
        """Create a CohereForCausalLM object from the given parameters.

        Builds a ``CohereConfig`` via ``CohereConfig.from_hugging_face``,
        instantiates the model from it, then fills in the weights with a
        ``ModelWeightsLoader``.

        Args:
            hf_model_or_dir: Hugging Face model source, forwarded to both the
                config factory and the weights loader.
            dtype: Data type forwarded to the config factory.
            mapping: Optional parallelism mapping forwarded to the config
                factory.
            quant_config: Optional quantization config forwarded to the
                config factory.
            **kwargs: Extra keyword arguments forwarded to the config factory.

        Returns:
            The constructed model after ``generate_tllm_weights`` has run on it.
        """
        config = CohereConfig.from_hugging_face(hf_model_or_dir,
                                                dtype=dtype,
                                                mapping=mapping,
                                                quant_config=quant_config,
                                                **kwargs)
        model = cls(config)
        # Name overrides handed to the loader; presumably these translate the
        # local q/k layernorm names to the checkpoint's names -- confirm
        # against ModelWeightsLoader.
        custom_dict = {
            'q_layernorm': 'q_norm',
            'k_layernorm': 'k_norm',
        }
        loader = ModelWeightsLoader(hf_model_or_dir, custom_dict)
        loader.generate_tllm_weights(model)
        return model