# Source code for tensorrt_llm.models.chatglm.config

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch

from ..._utils import torch_dtype_to_str
from ...mapping import Mapping
from ..modeling_utils import PretrainedConfig, QuantConfig

# Variants using the original GLM / ChatGLM-6B layout (GELU, LayerNorm,
# one K/V head per attention head).
GLM_ARCH1_VERSIONS = ['chatglm', 'glm']
# Variants using the ChatGLM2-style layout.
GLM_ARCH2_VERSIONS = ['glm4', 'chatglm3', 'chatglm2']
# Every supported version. Order matters: version inference picks the first
# entry that is a substring of the model path, so more specific names must
# precede the shorter names they contain ('chatglm3' > 'chatglm' > 'glm').
GLM_VERSIONS = GLM_ARCH2_VERSIONS + GLM_ARCH1_VERSIONS


def _infer_chatglm_version(name_or_path: str) -> Optional[str]:
    """Guess the GLM variant from a checkpoint name or path.

    Returns the first entry of ``GLM_VERSIONS`` that occurs as a substring of
    *name_or_path* (the list is ordered so more specific names are tried
    before the shorter names they contain), or ``None`` if nothing matches.
    A ``glm-4`` / ``glm_4`` spelling always resolves to ``'glm4'``.
    """
    version = next((v for v in GLM_VERSIONS if v in name_or_path), None)
    if 'glm_4' in name_or_path.replace("-", "_"):
        version = 'glm4'
    return version


def _hf_arch_params(hf_config, chatglm_version: str) -> dict:
    """Normalize version-specific HF config attributes into one dict.

    The GLM variants name (or omit) the same hyper-parameters differently.
    This reads them from *hf_config* without modifying it.

    Returns:
        dict with keys ``vocab_size``, ``num_kv_heads``, ``ffn_hidden_size``,
        ``hidden_act``, ``layernorm_epsilon``, ``max_position_embeddings``,
        ``add_bias_linear``, ``add_qkv_bias``,
        ``apply_residual_connection_post_layernorm``, ``rmsnorm`` and
        ``rope_ratio``.
    """
    if chatglm_version in GLM_ARCH1_VERSIONS:
        params = {
            'vocab_size': hf_config.vocab_size,
            # One K/V head per attention head.
            'num_kv_heads': hf_config.num_attention_heads,
            'hidden_act': 'gelu',
            'max_position_embeddings': hf_config.max_sequence_length,
            'add_bias_linear': True,
            'add_qkv_bias': True,
            'apply_residual_connection_post_layernorm': False,
            'rmsnorm': False,
            'rope_ratio': 1.0,
        }
        if chatglm_version == 'glm':
            # Fixed values for 'glm'; these are not read from the HF config.
            params['ffn_hidden_size'] = hf_config.hidden_size * 4
            params['layernorm_epsilon'] = 1e-5
        else:  # 'chatglm'
            params['ffn_hidden_size'] = hf_config.inner_hidden_size
            params['layernorm_epsilon'] = hf_config.layernorm_epsilon
        return params

    # GLM_ARCH2_VERSIONS: ChatGLM2-style configs.
    return {
        'vocab_size': hf_config.padded_vocab_size,
        'num_kv_heads': hf_config.multi_query_group_num,
        'ffn_hidden_size': hf_config.ffn_hidden_size,
        'hidden_act': 'swiglu',
        'layernorm_epsilon': hf_config.layernorm_epsilon,
        'max_position_embeddings': hf_config.seq_length,
        'add_bias_linear': hf_config.add_bias_linear,
        'add_qkv_bias': hf_config.add_qkv_bias,
        'apply_residual_connection_post_layernorm':
        hf_config.apply_residual_connection_post_layernorm,
        # Optional attributes. rmsnorm is a boolean flag, so default to True
        # (not the float 1.0) to keep the serialized config type-consistent.
        'rmsnorm': getattr(hf_config, 'rmsnorm', True),
        'rope_ratio': getattr(hf_config, 'rope_ratio', 1.0),
    }


def _rope_params(chatglm_version: str, rope_ratio: float):
    """Return ``(position_embedding_type, rotary_base, rotary_scaling)``.

    *chatglm_version* must already be validated against ``GLM_VERSIONS``.
    """
    position_embedding_type = {
        'glm': 'learned_absolute',
        'chatglm': 'chatglm',
    }.get(chatglm_version, 'rope_gptj')  # all GLM_ARCH2_VERSIONS
    rotary_base = 10000.0
    rotary_scaling = None
    if chatglm_version == 'chatglm2':
        # ChatGLM2 applies rope_ratio as linear rotary scaling...
        if rope_ratio > 1:
            rotary_scaling = {'type': 'linear', 'factor': rope_ratio}
    elif chatglm_version in ('chatglm3', 'glm4'):
        # ...whereas ChatGLM3 / GLM-4 fold it into the rotary base.
        rotary_base *= rope_ratio
    return position_embedding_type, rotary_base, rotary_scaling


def _resolve_dtype(hf_config, dtype: str) -> str:
    """Resolve ``'auto'`` to a dtype string taken from ``hf_config.torch_dtype``.

    A missing ``torch_dtype`` and ``float32`` both map to ``'float16'``.
    Non-``'auto'`` values are returned unchanged.
    """
    if dtype != 'auto':
        return dtype
    resolved = getattr(hf_config, 'torch_dtype', None)
    if resolved is None:
        return 'float16'
    if isinstance(resolved, torch.dtype):
        resolved = torch_dtype_to_str(resolved)
    return 'float16' if resolved == 'float32' else resolved


class ChatGLMConfig(PretrainedConfig):
    """``PretrainedConfig`` extended with the ChatGLM / GLM architecture switches."""

    def __init__(self,
                 *,
                 chatglm_version: str = 'chatglm3',
                 add_bias_linear: bool = False,
                 add_qkv_bias: bool = True,
                 apply_query_key_layer_scaling: bool = False,
                 apply_residual_connection_post_layernorm: bool = False,
                 rmsnorm: bool = True,
                 rotary_pct: float = 0.5,
                 rotary_base: float = 10000.0,
                 rotary_scaling: Optional[dict] = None,
                 **kwargs):
        """Store the GLM-specific fields and forward the rest to the base class.

        Args:
            chatglm_version: Model variant. Expected to be one of
                ``GLM_VERSIONS``; this is not validated here.
            add_bias_linear, add_qkv_bias, apply_query_key_layer_scaling,
            apply_residual_connection_post_layernorm, rmsnorm: Architecture
                flags, stored as-is.
            rotary_pct, rotary_base, rotary_scaling: Rotary-embedding
                settings, stored as-is.
            **kwargs: Passed to ``PretrainedConfig.__init__``.
        """
        self.chatglm_version = chatglm_version
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.rmsnorm = rmsnorm
        self.rotary_pct = rotary_pct
        self.rotary_base = rotary_base
        self.rotary_scaling = rotary_scaling
        super().__init__(**kwargs)

    def to_dict(self):
        """Return the base-class dict plus the fields added by ChatGLMConfig."""
        output = super().to_dict()
        # Serialize the fields added in ChatGLMConfig
        for name in ('chatglm_version', 'add_bias_linear', 'add_qkv_bias',
                     'apply_query_key_layer_scaling',
                     'apply_residual_connection_post_layernorm', 'rmsnorm',
                     'rotary_pct', 'rotary_base', 'rotary_scaling'):
            output[name] = getattr(self, name)
        return output

    @classmethod
    def from_hugging_face(
            cls,
            hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs):
        """Build a ``ChatGLMConfig`` from a Hugging Face config or checkpoint dir.

        Args:
            hf_config_or_dir: A ``transformers.PretrainedConfig``, or a path
                passed to ``transformers.AutoConfig.from_pretrained`` with
                ``trust_remote_code=True``. A config object passed in is only
                read, never modified.
            dtype: Target dtype, or ``'auto'`` to use the HF ``torch_dtype``.
                A missing ``torch_dtype`` and ``float32`` both become
                ``'float16'``.
            mapping: Parallelism mapping forwarded to the config.
            quant_config: Forwarded as ``quantization``.
            **kwargs: Recognized keys, with defaults:
                ``logits_dtype`` (``'float32'``),
                ``use_parallel_embedding`` (``False``),
                ``embedding_sharding_dim`` (``0``),
                ``share_embedding_table`` (``False``),
                ``chatglm_version`` (inferred from the config's
                ``_name_or_path`` when omitted).
                Any other keys are ignored.

        Raises:
            ValueError: If the version is unsupported or cannot be inferred.
        """
        import transformers

        # load hugging face config
        if isinstance(hf_config_or_dir, transformers.PretrainedConfig):
            hf_config = hf_config_or_dir
        else:
            hf_config = transformers.AutoConfig.from_pretrained(
                str(hf_config_or_dir), trust_remote_code=True)

        logits_dtype = kwargs.pop('logits_dtype', 'float32')
        use_parallel_embedding = kwargs.pop('use_parallel_embedding', False)
        embedding_sharding_dim = kwargs.pop('embedding_sharding_dim', 0)
        share_embedding_table = kwargs.pop('share_embedding_table', False)
        chatglm_version = kwargs.pop('chatglm_version', None)

        # get chatglm version
        if chatglm_version is None:
            print("Inferring chatglm version from path...")
            chatglm_version = _infer_chatglm_version(hf_config._name_or_path)
        # Explicit check (not assert): asserts vanish under `python -O`, and an
        # unknown version would otherwise fail later with an obscure error.
        if chatglm_version not in GLM_VERSIONS:
            raise ValueError(
                f"Unsupported or undetected chatglm_version "
                f"{chatglm_version!r}; expected one of {GLM_VERSIONS}. "
                f"Pass chatglm_version explicitly if it cannot be inferred "
                f"from the model path.")
        print(f"Chatglm version: {chatglm_version}")

        arch = _hf_arch_params(hf_config, chatglm_version)
        position_embedding_type, rotary_base, rotary_scaling = _rope_params(
            chatglm_version, arch['rope_ratio'])
        dtype = _resolve_dtype(hf_config, dtype)

        return cls(
            architecture=hf_config.architectures[0],
            dtype=dtype,
            logits_dtype=logits_dtype,
            num_hidden_layers=hf_config.num_layers,
            num_attention_heads=hf_config.num_attention_heads,
            num_key_value_heads=arch['num_kv_heads'],
            hidden_size=hf_config.hidden_size,
            intermediate_size=arch['ffn_hidden_size'],
            norm_epsilon=arch['layernorm_epsilon'],
            vocab_size=arch['vocab_size'],
            position_embedding_type=position_embedding_type,
            max_position_embeddings=arch['max_position_embeddings'],
            rotary_pct=0.5,
            rotary_base=rotary_base,
            rotary_scaling=rotary_scaling,
            hidden_act=arch['hidden_act'],
            use_parallel_embedding=use_parallel_embedding,
            embedding_sharding_dim=embedding_sharding_dim,
            share_embedding_table=share_embedding_table,
            quantization=quant_config,
            mapping=mapping,
            chatglm_version=chatglm_version,
            add_bias_linear=arch['add_bias_linear'],
            add_qkv_bias=arch['add_qkv_bias'],
            apply_query_key_layer_scaling=False,
            apply_residual_connection_post_layernorm=arch[
                'apply_residual_connection_post_layernorm'],
            rmsnorm=arch['rmsnorm'],
        )