Source code for tensorrt_llm.models.deepseek_v1.model
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional

from ..._utils import pad_vocab_size
from ...functional import Tensor, non_gated_version, recv, send
from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                       Embedding, GatedMLP, PositionEmbeddingType, RmsNorm,
                       SharedMoE)
from ...mapping import Mapping
from ...module import Module
from ...plugin import init_all_reduce_helper
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              PretrainedConfig)
from .config import DeepSeekV1Config
from .convert import convert_deepseek, load_hf_deepseek


class DeepseekDecoderLayer(Module):

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        ### Input layernorm in Deepseek v1 is same as Llama
        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                       eps=config.norm_epsilon,
                                       dtype=config.dtype)

        layers_range = config.mapping.pp_layers(config.num_hidden_layers)
        local_layer_idx = layer_idx - layers_range[0]

        ### Deepseek v1 model with standard attention
        self.attention = Attention(
            local_layer_idx=local_layer_idx,
            hidden_size=config.hidden_size,
            attention_head_size=config.head_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=False,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=config.rotary_base,
            rotary_embedding_scaling=config.rotary_scaling,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            tp_rank=config.mapping.tp_rank,
            quant_mode=config.quant_mode)

        ClsMLP = GatedMLP

        moe_config = config.moe
        if moe_config.num_experts > 0 and layer_idx > 0:
            mlp_hidden_size = config.moe_intermediate_size
            hidden_act = config.hidden_act
            mlp_kwargs = {'moe_config': moe_config, 'mapping': config.mapping}
            if moe_config.shared_expert_intermediate_size > 0:
                ClsMLP = SharedMoE
                mlp_kwargs['use_shared_gate'] = False
                mlp_kwargs['use_side_stream'] = False
            else:
                ClsMLP = MOE
        else:
            ClsMLP = GatedMLP
            mlp_hidden_size = config.intermediate_size
            hidden_act = non_gated_version(
                config.hidden_act)  # back to non-gated for dense layers
            mlp_kwargs = {}

        self.mlp = ClsMLP(hidden_size=config.hidden_size,
                          ffn_hidden_size=mlp_hidden_size,
                          hidden_act=hidden_act,
                          dtype=config.dtype,
                          bias=False,
                          tp_group=config.mapping.tp_group,
                          tp_size=config.mapping.tp_size,
                          quant_mode=config.quant_mode,
                          **mlp_kwargs)

        ### Post layernorm in Deepseek v1 is same as Llama
        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                      eps=config.norm_epsilon,
                                      dtype=config.dtype)

    def forward(self,
                hidden_states,
                attention_mask=None,
                use_cache=False,
                spec_decoding_params=None,
                kv_cache_params=None,
                attention_params=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            use_cache=use_cache,
            spec_decoding_params=spec_decoding_params,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params)

        if use_cache:
            attention_output, presents = attention_output

        hidden_states = residual + attention_output

        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        if use_cache:
            return (hidden_states, presents)
        return hidden_states


class DeepseekModel(Module):

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()
        init_all_reduce_helper()  # enable use_custom_all_reduce

        self.mapping = config.mapping
        if self.mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(config.vocab_size,
                                             config.hidden_size,
                                             dtype=config.dtype)

        self.layers = DecoderLayerList(DeepseekDecoderLayer, config)

        if self.mapping.is_last_pp_rank():
            self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                                eps=config.norm_epsilon,
                                dtype=config.dtype)

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                attention_mask=None,
                spec_decoding_params=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None):

        ptuning_args = [
            prompt_embedding_table, prompt_tasks, prompt_vocab_size
        ] if prompt_embedding_table is not None else []

        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

        hidden_states = self.layers.forward(
            hidden_states,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            spec_decoding_params=spec_decoding_params)

        if use_cache:
            hidden_states, presents = hidden_states

        if self.mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states
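
The feed-forward selection in DeepseekDecoderLayer.__init__ is the DeepSeek-V1-specific part of the layer: the first decoder layer (and any config without experts) stays a dense GatedMLP, while every later layer becomes a routed-expert block, using SharedMoE when the config also carries shared experts. The snippet below is a minimal, self-contained sketch of that branch only; FakeMoeConfig, select_mlp_class, and the example values are hypothetical stand-ins for illustration, not TensorRT-LLM APIs or the module above.

# Standalone sketch (not part of tensorrt_llm.models.deepseek_v1.model):
# mirrors the ClsMLP branch in DeepseekDecoderLayer.__init__ as plain strings.
from dataclasses import dataclass


@dataclass
class FakeMoeConfig:  # hypothetical stand-in for config.moe
    num_experts: int = 0
    shared_expert_intermediate_size: int = 0


def select_mlp_class(moe_config: FakeMoeConfig, layer_idx: int) -> str:
    if moe_config.num_experts > 0 and layer_idx > 0:
        if moe_config.shared_expert_intermediate_size > 0:
            return "SharedMoE"  # routed experts plus always-on shared experts
        return "MOE"            # routed experts only
    return "GatedMLP"           # dense FFN: layer 0, or no experts configured


# Illustrative values only; a real config comes from DeepSeekV1Config.
moe = FakeMoeConfig(num_experts=64, shared_expert_intermediate_size=2816)
print([select_mlp_class(moe, i) for i in range(3)])
# ['GatedMLP', 'SharedMoE', 'SharedMoE']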