# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..._utils import pad_vocab_size
from ...functional import Tensor
from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
                       Embedding, LayerNorm, PositionEmbeddingType)
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              PretrainedConfig)


class BloomDecoderLayer(Module):

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        hidden_size = config.hidden_size
        dtype = config.dtype
        tp_group = config.mapping.tp_group
        tp_size = config.mapping.tp_size
        tp_rank = config.mapping.tp_rank

        self.input_layernorm = LayerNorm(normalized_shape=hidden_size,
                                         dtype=dtype)

        # Map the global layer index to a local index within this
        # pipeline-parallel rank's slice of layers.
        layers_range = config.mapping.pp_layers(config.num_hidden_layers)
        local_layer_idx = layer_idx - layers_range[0]

        # BLOOM uses ALiBi position biases with a causal attention mask;
        # the attention projections carry biases and are sharded across the
        # tensor-parallel group.
        self.attention = Attention(
            local_layer_idx=local_layer_idx,
            hidden_size=hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            num_layers=config.num_hidden_layers,
            dtype=dtype,
            attention_mask_type=AttentionMaskType.causal,
            position_embedding_type=PositionEmbeddingType.alibi,
            bias=True,
            tp_group=tp_group,
            tp_size=tp_size,
            tp_rank=tp_rank,
            reorder=True,
            quant_mode=config.quant_mode)

        # The feed-forward width defaults to 4x the hidden size when the
        # config does not specify an intermediate size.
        mlp_hidden_size = (hidden_size * 4 if config.intermediate_size is None
                           else config.intermediate_size)

        self.mlp = MLP(hidden_size=hidden_size,
                       ffn_hidden_size=mlp_hidden_size,
                       hidden_act='gelu',
                       dtype=dtype,
                       bias=True,
                       tp_group=tp_group,
                       tp_size=tp_size,
                       quant_mode=config.quant_mode)
        self.post_layernorm = LayerNorm(normalized_shape=hidden_size,
                                        dtype=dtype)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None):
        assert isinstance(hidden_states, Tensor)

        # Pre-layernorm self-attention block with a residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(hidden_states,
                                          attention_mask=attention_mask,
                                          use_cache=use_cache,
                                          kv_cache_params=kv_cache_params,
                                          attention_params=attention_params)

        if use_cache:
            attention_output, presents = attention_output

        hidden_states = residual + attention_output

        # Pre-layernorm MLP block with a residual connection.
        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        if use_cache:
            return (hidden_states, presents)
        return hidden_states