# Source code for tensorrt_llm.models.modeling_utils

import argparse
import copy
import dataclasses
import json
import os
import re
from enum import IntFlag, auto
from functools import cached_property
from pathlib import Path
from typing import (TYPE_CHECKING, Callable, Dict, Generator, List, Optional,
                    Union)

import numpy as np
import safetensors
import torch

from .._common import default_net
from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch,
                      release_gc, str_dtype_to_torch, str_dtype_to_trt,
                      trt_dtype_to_torch)
from ..bindings import KVCacheType
from ..functional import (PositionEmbeddingType, Tensor,
                          gather_last_token_logits, tanh)
from ..layers import (MLP, AttentionParams, Embedding, FusedGatedMLP,
                      FusedRgLru, GatedMLP, KeyValueCacheParams, LoraParams,
                      PromptTuningEmbedding, RgLru)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Lora
from ..layers.moe import MOE, MoeOOTB
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..parameter import Parameter
from ..plugin import init_all_reduce_helper
from ..quantization import QuantMode
from ..quantization.layers import (Fp8RowwiseFusedGatedMLP, Fp8RowwiseGatedMLP,
                                   WeightOnlyGroupwiseQuantLinear,
                                   WeightOnlyGroupwiseQuantRowLinear,
                                   WeightOnlyQuantLinear,
                                   WeightOnlyQuantRowLinear)
from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST,
                                 W8A8_SQ_PLUGIN_LIST, QuantAlgo)
from ..top_model_mixin import TopModelMixin
from .convert_utils import weight_only_quantize_dict
from .generation_mixin import GenerationMixin


@dataclasses.dataclass(kw_only=True, frozen=True)
class Gemma2ConfigGroup:
    query_pre_attn_scalar: int
    final_logit_softcapping: Optional[float]
    attn_logit_softcapping: Optional[float]

    @classmethod
    def keys(cls):
        return {f.name for f in dataclasses.fields(cls)}


# Declarations needed only by static type checkers; this block never runs at
# runtime, so the names below may only be used inside string annotations
# (e.g. "Type[CG]", "Generator[Self, None, None]").
if TYPE_CHECKING:
    from typing import Type, TypeVar

    from typing_extensions import Self

    # Union of every known config-group dataclass (currently only Gemma2).
    ConfigGroups = Union[Gemma2ConfigGroup]
    """Groupings of config where, if one of said properties exists, we assume all of the properties exist (even if they are `None`)"""
    # Type variable bound to a config-group class, so helpers taking a group
    # class can be typed as returning an instance of that same class.
    CG = TypeVar("CG", bound=ConfigGroups)


class SpeculativeDecodingMode(IntFlag):
    """Speculative-decoding strategies understood by the runtime.

    Members are created with ``auto()`` on an ``IntFlag``, so each one is a
    distinct bit value.
    """
    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/runtime/speculativeDecodingMode.h
    NONE = auto()
    DRAFT_TOKENS_EXTERNAL = auto()
    MEDUSA = auto()
    LOOKAHEAD_DECODING = auto()
    EXPLICIT_DRAFT_TOKENS = auto()
    EAGLE = auto()

    @staticmethod
    def from_arguments(args: argparse.Namespace):
        """Translate ``args.speculative_decoding_mode`` into an enum member.

        Args:
            args: Parsed CLI namespace with a ``speculative_decoding_mode``
                attribute that is ``None`` or one of the lowercase names
                listed below.

        Returns:
            The matching ``SpeculativeDecodingMode``; ``NONE`` when the
            attribute is ``None``.

        Raises:
            AssertionError: if the name is not recognized. Raised explicitly
                rather than via ``assert`` so the check survives ``python -O``
                (where a bare assert would be stripped and ``None`` returned).
        """
        mode = args.speculative_decoding_mode
        if mode is None:
            return SpeculativeDecodingMode.NONE

        # Defined locally (not as a class attribute) so it does not become an
        # enum member.
        name_to_mode = {
            "draft_tokens_external":
            SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL,
            "medusa": SpeculativeDecodingMode.MEDUSA,
            "lookahead_decoding": SpeculativeDecodingMode.LOOKAHEAD_DECODING,
            "explicit_draft_tokens":
            SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS,
            "eagle": SpeculativeDecodingMode.EAGLE,
        }
        if mode not in name_to_mode:
            raise AssertionError(f"Unknown speculative_decoding_mode {mode}")
        if mode == "eagle":
            logger.warning("EAGLE is not supported yet. Do not use it.")
        return name_to_mode[mode]
@dataclasses.dataclass
class QuantConfig:
    '''Serializable quantization configuration class, part of the PretrainedConfig.

    Holds a single quantization algorithm for weights/activations
    (``quant_algo``) and for the KV cache (``kv_cache_quant_algo``), plus the
    tuning knobs consumed by the quantization flow.
    '''

    quant_algo: Optional[QuantAlgo] = None
    kv_cache_quant_algo: Optional[QuantAlgo] = None
    group_size: Optional[int] = 128
    smoothquant_val: float = 0.5
    clamp_val: Optional[List[float]] = None
    use_meta_recipe: bool = False
    has_zero_point: Optional[bool] = False
    pre_quant_scale: Optional[bool] = False
    exclude_modules: Optional[List[str]] = None

    @property
    def use_plugin_sq(self):
        """True when ``quant_algo`` is one of the SmoothQuant plugin algorithms."""
        return self.quant_algo in W8A8_SQ_PLUGIN_LIST

    @cached_property
    def quant_mode(self) -> QuantModeWrapper:
        """Model-wide quant modes, wrapped; a single entry for this config."""
        quant_mode_list = [
            QuantMode.from_quant_algo(
                self.quant_algo,
                self.kv_cache_quant_algo,
            )
        ]
        return QuantModeWrapper(quant_mode_list)

    @cached_property
    def layer_quant_mode(self) -> QuantMode:
        """The ``QuantMode`` applied to every layer under this config."""
        return QuantMode.from_quant_algo(
            self.quant_algo,
            self.kv_cache_quant_algo,
        )

    @property
    def requires_calibration(self):
        """True if the weight or KV-cache algorithm needs calibration data.

        Weight-only INT8/INT4 and FP8 per-channel-per-token are excluded from
        the weight-side check.
        """
        return self.quant_algo in (set(QUANT_ALGO_LIST) - {
            QuantAlgo.W8A16, QuantAlgo.W4A16,
            QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
        }) or self.kv_cache_quant_algo in KV_CACHE_QUANT_ALGO_LIST

    @property
    def requires_modelopt_quantization(self):
        """True if this config must be produced through ModelOpt quantization."""
        if self.quant_algo in [
                QuantAlgo.W4A16_AWQ, QuantAlgo.FP8,
                QuantAlgo.W8A8_SQ_PER_CHANNEL, QuantAlgo.W4A8_AWQ,
                QuantAlgo.MIXED_PRECISION
        ]:
            return True
        # FP8 KV cache alone (no weight quantization) also goes through ModelOpt.
        return (self.quant_algo is None
                and self.kv_cache_quant_algo == QuantAlgo.FP8)

    def get_quant_cfg(self, module_name=None):
        """Return the config for ``module_name``; one config covers all modules."""
        return self

    def get_modelopt_qformat(self):
        """Return the ModelOpt ``qformat`` string for ``quant_algo``.

        Returns ``'full_prec'`` when ``quant_algo`` is ``None``.

        Raises:
            AssertionError: for ``MIXED_PRECISION`` (handled by
                ``LayerQuantConfig``) or an algorithm ModelOpt is not used for.
        """
        algo_to_modelopt_map = {
            QuantAlgo.W8A16: "int8_wo",
            QuantAlgo.W4A16: "int4_wo",
            QuantAlgo.W4A16_AWQ: "int4_awq",
            QuantAlgo.W4A8_AWQ: 'w4a8_awq',
            QuantAlgo.FP8: 'fp8',
            QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
        }
        assert self.quant_algo != QuantAlgo.MIXED_PRECISION, "We don't support mixed precision in QuantConfig"
        if self.quant_algo is not None:
            assert self.quant_algo in algo_to_modelopt_map, f"We don't use Modelopt for quantization algorithm {self.quant_algo}, you probably shall not call this"
            return algo_to_modelopt_map[self.quant_algo]
        return 'full_prec'

    def get_modelopt_kv_cache_dtype(self):
        """Return the ModelOpt KV-cache dtype string, or ``None`` if unquantized."""
        algo_to_modelopt_map = {
            QuantAlgo.FP8: 'fp8',
            QuantAlgo.INT8: 'int8',
        }
        if self.kv_cache_quant_algo is not None:
            assert self.kv_cache_quant_algo in algo_to_modelopt_map, f"We don't use Modelopt for quantization algorithm {self.kv_cache_quant_algo}, you probably shall not call this"
            return algo_to_modelopt_map[self.kv_cache_quant_algo]
        return None

    @classmethod
    def from_dict(cls, config: dict):
        """Build a config from a dict of field values (keys must be field names)."""
        obj = cls(**config)
        return obj

    def to_dict(self):
        """Serialize the dataclass fields (cached properties are not included)."""
        return dataclasses.asdict(self)


@dataclasses.dataclass
class LayerQuantConfig(QuantConfig):
    '''Per-layer (mixed-precision / AutoQuant) quantization configuration.

    ``quantized_layers`` maps module names to their own ``QuantConfig``.
    ``@dataclass`` does not replace the hand-written ``__init__`` below; it
    only contributes field metadata, ``__repr__`` and ``__eq__``.
    '''

    quant_algo: Optional[QuantAlgo] = None
    kv_cache_quant_algo: Optional[QuantAlgo] = None
    quantized_layers: Optional[Dict[str, QuantConfig]] = None
    exclude_modules: Optional[List[str]] = None

    def __init__(self,
                 *,
                 quant_algo: Optional[QuantAlgo] = None,
                 kv_cache_quant_algo: Optional[QuantAlgo] = None,
                 quantized_layers: Optional[Dict[str, QuantConfig]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 **kwargs):
        """Store the per-layer configs and precompute each layer's QuantMode.

        Unknown keyword arguments are ignored with a warning.
        """
        # Attribute assignment order determines the key order of to_dict().
        self.quant_algo = quant_algo
        # The declared default is None; treat it as "no per-layer overrides"
        # instead of failing on None.items() below.
        self.quantized_layers = (quantized_layers
                                 if quantized_layers is not None else {})
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules
        self.auto_quant_mode = {
            name:
            QuantMode.from_quant_algo(
                layer_config.quant_algo,
                self.kv_cache_quant_algo,
            )
            for name, layer_config in self.quantized_layers.items()
        }
        for key, value in kwargs.items():
            logger.warning(
                f"Warning: Unrecognized parameter '{key}' with value '{value}'"
            )

    @cached_property
    def quant_mode(self):
        """Distinct QuantModes used across all layers, wrapped."""
        quant_mode_list = list(set(self.auto_quant_mode.values()))
        return QuantModeWrapper(quant_mode_list)

    @property
    def layer_quant_mode(self) -> Dict[str, QuantMode]:
        """Mapping of module name to its QuantMode."""
        return self.auto_quant_mode

    @cached_property
    def auto_quant_list(self):
        """Distinct per-layer ``quant_algo`` values, in first-seen order.

        First-seen order (instead of set order) keeps the result, and the
        qformat string built from it, reproducible across interpreter runs.
        """
        return list(
            dict.fromkeys(layer_config.quant_algo
                          for layer_config in self.quantized_layers.values()))

    @classmethod
    def from_dict(cls, config: dict):
        """Build from a dict whose ``quantized_layers`` values are QuantConfig dicts.

        The caller's dict is not modified.
        """
        config = dict(config)
        quantized_layers = config.pop('quantized_layers', None) or {}
        quantized_layers_dict = {
            layer_name: QuantConfig(**layer_config)
            for layer_name, layer_config in quantized_layers.items()
        }
        obj = cls(quantized_layers=quantized_layers_dict, **config)
        return obj

    def get_quant_cfg(self, module_name):
        """Return the QuantConfig registered for ``module_name``."""
        assert module_name in self.quantized_layers.keys(), \
            f"module {module_name} should be included in `quantized_layers` in AutoQuant mode"
        return self.quantized_layers[module_name]

    def get_modelopt_qformat(self):
        """Return a comma-separated ModelOpt qformat list for the layer algorithms."""
        algo_to_modelopt_map = {
            QuantAlgo.W4A16_AWQ: "int4_awq",
            QuantAlgo.W4A8_AWQ: 'w4a8_awq',
            QuantAlgo.FP8: 'fp8',
            QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
        }
        assert self.quant_algo == QuantAlgo.MIXED_PRECISION, "We only support mixed precision quantization in LayerQuantConfig"
        autoq_format = ','.join(
            [algo_to_modelopt_map[item] for item in self.auto_quant_list])
        return autoq_format

    def to_dict(self):
        """Serialize to a JSON-friendly dict, dropping derived/cached state."""
        output = copy.deepcopy(self.__dict__)
        # cached_property values live in __dict__ once accessed; none of them
        # (nor exclude_modules) belong in the serialized schema.
        for derived_key in ('auto_quant_mode', 'quant_mode', 'auto_quant_list',
                            'exclude_modules'):
            output.pop(derived_key, None)
        for name, per_layer_config in output['quantized_layers'].items():
            per_layer_config = per_layer_config.to_dict()
            per_layer_config.pop('exclude_modules')
            output['quantized_layers'][name] = per_layer_config
        return output
class PretrainedConfig:
    """Base configuration shared by TensorRT-LLM model definitions.

    Core architecture fields are keyword-only. Several optional fields get
    derived defaults: ``num_key_value_heads`` defaults to
    ``num_attention_heads``, ``intermediate_size`` to ``4 * hidden_size``,
    and ``head_size`` to ``hidden_size // num_attention_heads``.
    ``mapping`` and ``quantization`` accept either the object or its dict
    form. Any extra keyword argument is stored as an attribute, with a
    warning.
    """

    def __init__(self,
                 *,
                 architecture: str,
                 dtype: str,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 vocab_size: Optional[int] = None,
                 hidden_act: str = 'gelu',
                 logits_dtype: str = 'float32',
                 norm_epsilon: float = 1e-5,
                 position_embedding_type: Union[
                     PositionEmbeddingType,
                     str] = PositionEmbeddingType.learned_absolute,
                 max_position_embeddings: Optional[int] = None,
                 num_key_value_heads: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 mapping: Optional[Union[Mapping, dict]] = None,
                 quantization: Optional[Union[QuantConfig, dict]] = None,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 share_embedding_table: bool = False,
                 head_size: Optional[int] = None,
                 qk_layernorm: bool = False,
                 **kwargs):
        self.architecture = architecture
        self.dtype = dtype
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act

        self.logits_dtype = logits_dtype
        self.norm_epsilon = norm_epsilon

        if isinstance(position_embedding_type, str):
            position_embedding_type = PositionEmbeddingType.from_string(
                position_embedding_type)
        assert isinstance(position_embedding_type, PositionEmbeddingType)
        self.position_embedding_type = position_embedding_type

        self.max_position_embeddings = max_position_embeddings

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        if intermediate_size is None:
            intermediate_size = hidden_size * 4
        self.intermediate_size = intermediate_size

        if mapping is None:
            mapping = Mapping()
        elif isinstance(mapping, dict):
            mapping = Mapping.from_dict(mapping)
        assert isinstance(mapping, Mapping)
        self.mapping = mapping

        if quantization is None:
            quantization = QuantConfig()
        elif isinstance(quantization, dict):
            quantization = QuantConfig.from_dict(quantization)
        assert isinstance(quantization, QuantConfig)
        self.quantization = quantization

        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim
        self.share_embedding_table = share_embedding_table

        # A shared (tied) embedding table is only supported with TP when the
        # embedding is split along the vocab dimension (dim 0).
        if share_embedding_table and mapping.tp_size > 1:
            if not use_parallel_embedding or embedding_sharding_dim == 1:
                raise NotImplementedError(
                    "For tensor parallelism, sharing the embedding table must set "
                    "use_parallel_embedding=True and embedding_sharding_dim=0")
        if share_embedding_table and mapping.pp_size > 1:
            raise NotImplementedError(
                "Embedding table cannot be shared for pipeline parallelism")
        if share_embedding_table and mapping.cp_size > 1:
            raise NotImplementedError(
                "Embedding table cannot be shared for context parallelism")

        if head_size is None:
            head_size = hidden_size // num_attention_heads
        self.head_size = head_size
        self.qk_layernorm = qk_layernorm

        # Unknown keys become plain attributes. setattr raises AttributeError
        # for read-only properties (e.g. `kv_dtype`), which propagates.
        for key, value in kwargs.items():
            setattr(self, key, value)
            logger.warning(
                f"Implicitly setting {self.__class__.__name__}.{key} = {value}"
            )

    @property
    def kv_dtype(self):
        """KV-cache dtype string: 'int8', 'fp8', or the model dtype."""
        # TODO: need to align the kv dtype
        # now assume the kv cache is for all layers
        if self.quant_mode.has_int8_kv_cache():
            return 'int8'
        elif self.quant_mode.has_fp8_kv_cache():
            return 'fp8'
        else:
            return self.dtype

    def set_if_not_exist(self, key, value):
        """Set attribute ``key`` to ``value`` only if it is not already present."""
        if not hasattr(self, key):
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config: dict):
        """Instantiate the config class registered for ``config['architecture']``.

        Falls back to ``cls`` when the model class has no ``config_class``.
        """
        # Maybe we need AutoConfig for this
        from . import MODEL_MAP
        model_cls = MODEL_MAP[config['architecture']]
        config_cls = getattr(model_cls, 'config_class', cls)
        return config_cls(**config)

    def to_dict(self):
        """Return a deep-copied, JSON-friendly dict of this config.

        The mapping is serialized without its ``rank``.
        """
        output = copy.deepcopy(self.__dict__)
        output['position_embedding_type'] = str(self.position_embedding_type)
        output['mapping'] = self.mapping.to_dict()
        output['mapping'].pop('rank')
        output['quantization'] = self.quantization.to_dict()
        return output

    @classmethod
    def from_json_file(cls, config_file: str):
        """Load a config from JSON.

        For MIXED_PRECISION, also loads the per-layer config from the
        sibling ``quant_cfg.json`` file.

        Raises:
            RuntimeError: if the per-layer quantization config cannot be read.
        """
        with open(config_file) as f:
            config = json.load(f)
        obj = cls.from_dict(config)
        if obj.quantization.quant_algo == QuantAlgo.MIXED_PRECISION:
            layer_config_path = str(config_file).replace(
                'config.json', 'quant_cfg.json')
            try:
                obj.to_layer_quant_config(layer_config_path)
            except Exception as e:
                raise RuntimeError(
                    f"Encounter error '{e}' for read quantization config '{layer_config_path}'"
                ) from e
        return obj

    @classmethod
    def from_checkpoint(cls, ckpt_dir: str):
        """Load ``config.json`` from a checkpoint directory."""
        return cls.from_json_file(os.path.join(ckpt_dir, 'config.json'))

    def to_json_file(self, config_file: str):
        """Write ``to_dict()`` to ``config_file`` as indented JSON."""
        with open(config_file, 'w') as f:
            json.dump(self.to_dict(), f, indent=4)

    def to_layer_quant_config(self, config_file: str):
        """Replace ``self.quantization`` with a LayerQuantConfig read from JSON."""
        with open(config_file) as f:
            config = json.load(f)
        self.quantization = LayerQuantConfig.from_dict(config)

    @property
    def quant_mode(self):
        """Shortcut for ``self.quantization.quant_mode``."""
        return self.quantization.quant_mode

    @property
    def quant_algo(self):
        """Shortcut for ``self.quantization.quant_algo``."""
        return self.quantization.quant_algo

    def get_quant_cfg(self, module_name: str):
        """Return the quantization config that applies to ``module_name``."""
        return self.quantization.get_quant_cfg(module_name)

    def set_rank(self, rank):
        """Rebuild ``self.mapping`` for ``rank``, keeping the parallel sizes.

        NOTE(review): only the fields passed below are carried over; confirm
        Mapping has no other fields that must be preserved.
        """
        self.mapping = Mapping(self.mapping.world_size,
                               rank=rank,
                               cp_size=self.mapping.cp_size,
                               tp_size=self.mapping.tp_size,
                               pp_size=self.mapping.pp_size,
                               moe_tp_size=self.mapping.moe_tp_size,
                               moe_ep_size=self.mapping.moe_ep_size,
                               gpus_per_node=self.mapping.gpus_per_node)

    def get_config_group(self, group_cls: "Type[CG]") -> "CG":
        """Build a ``group_cls`` instance from this config's matching keys."""
        cfg = {k: v for k, v in self.to_dict().items() if k in group_cls.keys()}
        return group_cls(**cfg)

    def has_config_group(self, group_cls: "Type[CG]") -> "bool":
        """True if every field of ``group_cls`` is an attribute of this config."""
        return all(hasattr(self, key) for key in group_cls.keys())

    def for_each_rank(self) -> "Generator[Self, None, None]":
        """Yield a deep copy of this config for every rank in the world."""
        for rank in range(self.mapping.world_size):
            config_copy = copy.deepcopy(self)
            config_copy.set_rank(rank)
            yield config_copy
class DecoderLayerList(ModuleList):
    """The decoder layers owned by this pipeline-parallel rank.

    ``layer_list`` holds the global indices returned by
    ``mapping.pp_layers``. Each layer is built as ``cls(config, global_idx)``
    and stored at the matching local position of this list.
    """

    def __init__(self, cls, config):
        self.num_hidden_layers = config.num_hidden_layers
        self.layer_list = config.mapping.pp_layers(config.num_hidden_layers)
        super().__init__([cls(config, idx) for idx in self.layer_list])

    def forward(self,
                hidden_states,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                position_ids=None,
                lora_params=None,
                spec_decoding_params=None,
                vision_token_mask=None):
        """Run every local layer in order.

        Returns ``hidden_states``, or ``(hidden_states, presents)`` when
        ``use_cache`` is set.

        ``kv_cache_params`` must not be ``None``: it is dereferenced
        unconditionally, and each layer receives its own
        ``KeyValueCacheParams`` holding only that layer's past entry.
        """
        kv_cache_params.fill_none_tensor_list(len(self.layer_list))

        presents = []
        num_local_layers = len(self.layer_list)

        # `layer_idx` is the LOCAL position within this rank's layers.
        for layer_idx, (layer, past) in enumerate(
                zip(self, kv_cache_params.past_key_value)):

            lora_layer_params = None
            if lora_params is not None and lora_params.lora_ranks is not None:
                lora_layer_params = lora_params.get_layer_params(layer_idx)

            kwargs = {}
            if position_ids is not None:
                kwargs['position_ids'] = position_ids
            if vision_token_mask is not None:
                kwargs['vision_token_mask'] = vision_token_mask
            if lora_layer_params is not None:
                kwargs['lora_layer_params'] = lora_layer_params
            if spec_decoding_params is not None:
                kwargs['spec_decoding_params'] = spec_decoding_params
            if default_net().plugin_config.reduce_fusion:
                # Pass the next local layer's input layernorm so it can be
                # fused. The last local layer has no next layer on this rank.
                # Index with local positions only: mixing in the global
                # indices from layer_list picks the wrong layer on pipeline
                # ranks after the first.
                if layer_idx < num_local_layers - 1:
                    next_layer = self[layer_idx + 1]
                    kwargs['next_layer_input_layernorm_args'] = (
                        next_layer.input_layernorm.weight.value,
                        next_layer.input_layernorm.eps)
                else:
                    kwargs['next_layer_input_layernorm_args'] = None

            hidden_states = layer(
                hidden_states,
                use_cache=use_cache,
                attention_mask=attention_mask,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_attention_window_sizes=kv_cache_params.
                    host_max_attention_window_sizes,
                    host_sink_token_length=kv_cache_params.
                    host_sink_token_length,
                    kv_cache_block_offsets=kv_cache_params.
                    kv_cache_block_offsets,
                    host_kv_cache_block_offsets=kv_cache_params.
                    host_kv_cache_block_offsets,
                    host_kv_cache_pool_pointers=kv_cache_params.
                    host_kv_cache_pool_pointers,
                    host_kv_cache_pool_mapping=kv_cache_params.
                    host_kv_cache_pool_mapping,
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                **kwargs)

            # With caching, each layer returns (hidden_states, present).
            if use_cache:
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]

        if use_cache:
            return hidden_states, presents
        return hidden_states


class PostInitCaller(type):
    """Metaclass that calls ``obj.__post_init__()`` right after construction."""

    def __call__(cls, *args, **kwargs):
        obj = type.__call__(cls, *args, **kwargs)
        obj.__post_init__()
        return obj
class PretrainedModel(Module,
                      GenerationMixin,
                      TopModelMixin,
                      metaclass=PostInitCaller):
    """Base class for TensorRT-LLM models built from a PretrainedConfig.

    Through the ``PostInitCaller`` metaclass, ``__post_init__`` runs after
    the subclass ``__init__``. It applies quantization and ``optimize_model``
    before any weights are loaded.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        init_all_reduce_helper()
        self.config = config

    def __post_init__(self):
        from ..quantization.quantize import quantize
        quantize(self, self.config.quantization)

        # Currently, use_parallel_embedding and share_embedding_table must be enabled before weight loading;
        # otherwise, the model will be inconsistent with the weights loaded from checkpoint.
        optimize_model(
            self,
            use_parallel_embedding=self.config.use_parallel_embedding,
            share_embedding_table=self.config.share_embedding_table,
        )

    def release(self):
        """Trigger ``release_gc()``."""
        release_gc()

    def __del__(self):
        self.release()

    def check_config(self, config):
        """Abstract hook; subclasses must implement it."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Construct the model (without weights) from ``config``."""
        return cls(config)

    @classmethod
    def from_checkpoint(
            cls,
            ckpt_dir: str,
            rank: Optional[int] = None,
            config: Optional[PretrainedConfig] = None,
            *,
            preprocess_weights_hook: Optional[Callable[[Dict[str, Tensor]],
                                                       Dict[str, Tensor]]] = None):
        """Load a model and its ``rank{N}.safetensors`` weights from ``ckpt_dir``.

        Args:
            ckpt_dir: Checkpoint directory containing ``config.json`` (used
                when ``config`` is None) and per-rank safetensors files.
            rank: If given, the config is re-targeted to this rank first.
            config: Optional pre-built config.
            preprocess_weights_hook: Optional transform applied to the raw
                weights dict before ``preprocess_weights``.
        """
        # `import safetensors` alone does not load the `torch` submodule.
        import safetensors.torch

        if config is None:
            config = PretrainedConfig.from_json_file(
                os.path.join(ckpt_dir, 'config.json'))

        if rank is not None:
            config.set_rank(rank)

        rank = config.mapping.rank
        weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')

        assert os.path.isfile(weights_path), \
            f"Checkpoint weights file not found: {weights_path}"
        weights = safetensors.torch.load_file(weights_path)
        is_checkpoint_pruned = getattr(config, 'is_pruned', False)

        if preprocess_weights_hook is not None:
            weights = preprocess_weights_hook(weights)

        weights = preprocess_weights(weights,
                                     config,
                                     from_pruned=is_checkpoint_pruned)
        model = cls(config)
        model.load(weights, from_pruned=is_checkpoint_pruned)
        return model

    def load(self, weights, from_pruned=False):
        """Assign ``weights`` (name -> tensor) to this model's parameters.

        Raises:
            RuntimeError: if an uninitialized parameter has no provided
                tensor, or if assigning a tensor fails.

        Provided names that match no parameter only produce a warning. With
        ``from_pruned``, ``set_value_or_dummy`` is used instead of a plain
        assignment.
        """
        named_params = list(self.named_parameters())
        expected_names = set()
        required_names = set()
        for name, param in named_params:
            expected_names.add(name)
            if not param.is_inited():
                required_names.add(name)

        provided_names = set(weights.keys())

        if not required_names.issubset(provided_names):
            raise RuntimeError(
                f"Required but not provided tensors:{required_names.difference(provided_names)}"
            )
        if not provided_names.issubset(expected_names):
            logger.warning(
                f"Provided but not expected tensors: {provided_names.difference(expected_names)}"
            )

        for name, param in named_params:
            if name not in provided_names:
                continue
            if from_pruned:
                param.set_value_or_dummy(weights[name])
                continue
            try:
                param.value = weights[name]
            except Exception as e:
                raise RuntimeError(
                    f"Encounter error '{e}' for parameter '{name}'") from e

    def save_checkpoint(self, output_dir, save_config=True):
        """Write ``rank{N}.safetensors`` (and optionally ``config.json``).

        Multiple ranks could share the same ``config.json``; ``save_config``
        lets callers avoid writing it from every rank.
        """
        import safetensors.torch

        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        if save_config:
            self.config.to_json_file(os.path.join(output_dir, 'config.json'))

    def prepare_inputs(
            self,
            max_batch_size,
            max_input_len,
            max_seq_len,
            max_num_tokens,
            use_cache,
            max_beam_width: int = 1,
            opt_num_tokens: Optional[int] = None,
            prompt_embedding_table_size: int = 0,
            position_encoding_2d: bool = False,
            max_draft_len: int = 0,
            speculative_decoding_draft_tokens_external: bool = False,
            spec_decoding_is_generation_length_variable: bool = False,
            gather_context_logits: bool = False,
            gather_generation_logits: bool = False,
            lora_target_modules: Optional[List[str]] = None,
            opt_batch_size: int = 0):
        '''@brief: Prepare input Tensors for the model. The given sizes determine
            the ranges of the dimensions when using TRT dynamic shapes.

            @return: a dict of keyword arguments that can be fed into self.forward()
        '''
        plugin_config = default_net().plugin_config
        remove_input_padding = plugin_config.remove_input_padding
        use_gpt_attention_plugin = plugin_config.gpt_attention_plugin
        use_gemm_plugin = plugin_config.gemm_plugin
        paged_kv_cache = plugin_config.paged_kv_cache
        tokens_per_block = plugin_config.tokens_per_block
        use_lora_plugin = plugin_config.lora_plugin
        multiple_profiles = plugin_config.multiple_profiles
        streamingllm = plugin_config.streamingllm
        pp_reduce_scatter = plugin_config.pp_reduce_scatter

        if not use_cache:
            kv_cache_type = KVCacheType.DISABLED
        elif paged_kv_cache:
            kv_cache_type = KVCacheType.PAGED
        else:
            kv_cache_type = KVCacheType.CONTINUOUS

        model_inputs = self.prepare_basic_inputs(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_seq_len=max_seq_len,
            hidden_size=self.config.hidden_size,
            num_kv_heads=self.config.num_key_value_heads,
            head_size=self.config.head_size,
            num_layers=self.config.num_hidden_layers,
            kv_dtype=str_dtype_to_trt(self.config.kv_dtype),
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            use_gemm_plugin=use_gemm_plugin,
            kv_cache_type=kv_cache_type,
            tokens_per_block=tokens_per_block,
            num_heads=self.config.num_attention_heads,
            max_num_tokens=max_num_tokens,
            opt_num_tokens=opt_num_tokens,
            dtype=str_dtype_to_trt(self.config.dtype),
            prompt_embedding_table_size=prompt_embedding_table_size,
            position_encoding_2d=position_encoding_2d,
            mapping=self.config.mapping,
            gather_context_logits=gather_context_logits,
            gather_generation_logits=gather_generation_logits,
            use_lora_plugin=use_lora_plugin,
            max_draft_len=max_draft_len,
            speculative_decoding_draft_tokens_external=
            speculative_decoding_draft_tokens_external,
            spec_decoding_is_generation_length_variable=
            spec_decoding_is_generation_length_variable,
            lora_target_modules=lora_target_modules,
            multiple_profiles=multiple_profiles,
            streamingllm=streamingllm,
            opt_batch_size=opt_batch_size,
            pp_reduce_scatter=pp_reduce_scatter)

        result = {
            'input_ids':
            model_inputs['input_ids'],
            'position_ids':
            model_inputs['position_ids'],
            'use_cache':
            kv_cache_type != KVCacheType.DISABLED,
            'last_token_ids':
            model_inputs['last_token_ids'],
            'attention_mask':
            model_inputs['attention_mask'],
            'kv_cache_params':
            KeyValueCacheParams(
                past_key_value=model_inputs['past_key_value'],
                host_past_key_value_lengths=model_inputs[
                    'host_past_key_value_lengths'],
                host_max_attention_window_sizes=model_inputs[
                    'host_max_attention_window_sizes'],
                host_sink_token_length=model_inputs['host_sink_token_length'],
                kv_cache_block_offsets=model_inputs['kv_cache_block_offsets'],
                host_kv_cache_block_offsets=model_inputs[
                    'host_kv_cache_block_offsets'],
                host_kv_cache_pool_pointers=model_inputs[
                    'host_kv_cache_pool_pointers'],
                host_kv_cache_pool_mapping=model_inputs[
                    'host_kv_cache_pool_mapping'],
                cache_indirection=model_inputs['cache_indirection'],
            ),
            'attention_params':
            AttentionParams(
                sequence_length=model_inputs['sequence_length'],
                context_lengths=model_inputs['context_lengths'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'],
                host_runtime_perf_knobs=model_inputs['host_runtime_perf_knobs'],
                host_context_progress=model_inputs['host_context_progress'],
            )
        }

        if prompt_embedding_table_size > 0:
            result['prompt_embedding_table'] = model_inputs[
                'prompt_embedding_table']
            result['prompt_tasks'] = model_inputs['tasks']
            result['prompt_vocab_size'] = model_inputs['prompt_vocab_size']
        if model_inputs['hidden_states_input'] is not None:
            result['hidden_states'] = model_inputs['hidden_states_input']
        if use_lora_plugin:
            result['lora_params'] = LoraParams(
                model_inputs['lora_ranks'],
                model_inputs['lora_weights_pointers'],
                host_context_lengths=model_inputs['host_context_lengths'],
                host_request_types=model_inputs['host_request_types'])
        if model_inputs['spec_decoding_params'] is not None:
            result['spec_decoding_params'] = model_inputs[
                'spec_decoding_params']

        return result

    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'auto',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        device: str = 'cuda',
        calib_dataset: str = 'cnn_dailymail',
        calib_batches: int = 512,
        calib_batch_size: int = 1,
        calib_max_seq_length: int = 512,
        random_seed: int = 1234,
        tokenizer_max_seq_length: int = 2048,
        **kwargs,
    ):
        """Quantize a Hugging Face model with ModelOpt and export a checkpoint.

        Raises:
            NotImplementedError: if the class has no ``config_class``, or if
                expert parallelism (``moe_ep_size > 1``) is requested.
            ValueError: if the config does not require ModelOpt quantization.
        """
        config_cls = getattr(cls, 'config_class', None)
        if config_cls is None:
            raise NotImplementedError(
                f"{cls.__name__} has not implemented corresponding config class, which is needed for correct config parsing."
            )
        config: PretrainedConfig = config_cls.from_hugging_face(
            hf_model_dir,
            dtype=dtype,
            mapping=mapping,
            quant_config=quant_config,
            **kwargs)
        if config.mapping.moe_ep_size > 1:
            raise NotImplementedError(
                "Quantization for expert parallelism is not supported")
        if not config.quantization.requires_modelopt_quantization:
            raise ValueError(
                f"The quant_config ({quant_config}) should not call modelopt quantization"
            )

        from ..quantization import quantize_and_export
        quantize_and_export(
            model_dir=str(hf_model_dir),
            device=device,
            calib_dataset=calib_dataset,
            dtype=config.dtype,
            qformat=config.quantization.get_modelopt_qformat(),
            kv_cache_dtype=config.quantization.get_modelopt_kv_cache_dtype(),
            calib_size=calib_batches,
            batch_size=calib_batch_size,
            calib_max_seq_length=calib_max_seq_length,
            awq_block_size=config.quantization.group_size,
            output_dir=output_dir,
            tp_size=config.mapping.tp_size,
            pp_size=config.mapping.pp_size,
            seed=random_seed,
            tokenizer_max_seq_length=tokenizer_max_seq_length,
        )
# NOTE(review): the lines below lost their newlines/indentation in extraction,
# and the last one runs into `add_lora`, whose body continues past this span,
# so they are kept byte-identical here. Contents, in order:
#
# DecoderModelForCausalLM(PretrainedModel): wraps a `transformer` and an
#   `lm_head`, and creates constant attention params at construction.
#   forward() fills the attention params and forwards only the optional
#   inputs that were given. On the last pipeline rank it gathers the
#   last-token hidden states, projects them with lm_head, and then applies,
#   when configured:
#     - `output_multiplier_scale`;
#     - division by `mup_width_multiplier`;
#     - Gemma2 final-logit soft-capping, tanh(x / cap) * cap.
#   It then marks the result as output 'logits'. Other ranks mark
#   'hidden_states_output'. With use_cache and a non-paged KV cache, each
#   layer's present is marked as 'present_key_value_{i}' (i = global index).
# fuse_gate_mlp(model, gemm_swiglu_plugin_dtype=None): replaces GatedMLP
#   layers with hidden_act silu/gelu by a FusedGatedMLP whose fused_fc weight
#   is concat([gate, fc], dim 0).
#     - Only unquantized or FP8 layers are fused.
#     - FP8 weights are dequantized, concatenated, and re-quantized with the
#       max of the two weight scaling factors.
#     - Fp8RowwiseGatedMLP is fused into Fp8RowwiseFusedGatedMLP, also
#       concatenating per_channel_scale.
# unfuse_qkv_gemm(model): splits each non-cross-attention Attention's qkv
#   ColumnLinear into q/k/v layers (asserts tp_size == 1).
#     - q/k/v are quantized with the qkv layer's quant config.
#     - Weights and bias are split along the output dim (dim 1 for
#       weight-only quantized linears, else dim 0).
#     - Other qkv parameters are copied over, and layer.qkv is set to None.
# fuse_rg_lru(model): replaces RgLru with FusedRgLru, concatenating
#   input_gate/recurrent_gate weights and biases along the last axis.
# set_prompt_tuning(model): replaces Embedding modules named
#   `vocab_embedding` with a PromptTuningEmbedding holding the same weights.
# add_lora(model, max_lora_rank): continues past this span. For
#   Attention/BertAttention the default max rank is the min of hidden_size
#   and the q / kv projection widths.
class DecoderModelForCausalLM(PretrainedModel): def __init__(self, config: PretrainedConfig, transformer, lm_head): super().__init__(config) self.transformer = transformer self.lm_head = lm_head self.mup_width_multiplier = getattr(config, 'mup_width_multiplier', None) # Create constant attention parameters to be reused by all layers. Attention.create_attention_const_params(self, config) self.position_embedding_type = config.position_embedding_type def forward(self, input_ids: Tensor, position_ids=None, use_cache=False, last_token_ids=None, attention_mask=None, kv_cache_params=None, attention_params=None, hidden_states=None, prompt_embedding_table: Optional[Tensor] = None, prompt_tasks: Optional[Tensor] = None, prompt_vocab_size: Optional[Tensor] = None, lora_params=None, spec_decoding_params=None): # fill attention params. attention_params = Attention.fill_attention_params( self, attention_params) kwargs = { 'input_ids': input_ids, 'position_ids': position_ids, 'use_cache': use_cache, 'attention_mask': attention_mask, 'kv_cache_params': kv_cache_params, 'attention_params': attention_params, } if lora_params is not None: kwargs['lora_params'] = lora_params if hidden_states is not None: kwargs['hidden_states'] = hidden_states if prompt_embedding_table is not None: kwargs['prompt_embedding_table'] = prompt_embedding_table if prompt_tasks is not None: kwargs['prompt_tasks'] = prompt_tasks if prompt_vocab_size is not None: kwargs['prompt_vocab_size'] = prompt_vocab_size if spec_decoding_params is not None: kwargs['spec_decoding_params'] = spec_decoding_params hidden_states = self.transformer.forward(**kwargs) if use_cache: hidden_states, presents = hidden_states if self.config.mapping.is_last_pp_rank(): hidden_states = gather_last_token_logits( hidden_states, last_token_ids, default_net().plugin_config.remove_input_padding) # [batch_size, hidden_size] -> [batch_size, vocab_size] lm_logits = self.lm_head(hidden_states) if hasattr(self.config, 'output_multiplier_scale'): 
lm_logits *= getattr(self.config, 'output_multiplier_scale', 1) if self.mup_width_multiplier is not None: lm_logits = lm_logits / self.mup_width_multiplier if self.config.has_config_group(Gemma2ConfigGroup): softcap = self.config.get_config_group( Gemma2ConfigGroup).final_logit_softcapping if softcap: lm_logits = lm_logits * float(1 / softcap) lm_logits = tanh(lm_logits) * float(softcap) lm_logits.mark_output('logits', self.config.logits_dtype) else: hidden_states.mark_output('hidden_states_output', self.config.dtype) if use_cache and not default_net().plugin_config.paged_kv_cache: for i, present in zip( self.config.mapping.pp_layers( self.config.num_hidden_layers), presents): present.mark_output(f'present_key_value_{i}', self.config.kv_dtype) if self.config.mapping.is_last_pp_rank(): return (lm_logits, presents, hidden_states) return (hidden_states, presents) else: if self.config.mapping.is_last_pp_rank(): return lm_logits, hidden_states return hidden_states def fuse_gate_mlp( model: PretrainedModel, gemm_swiglu_plugin_dtype: Optional[str] = None, ) -> PretrainedModel: from ..quantization.quantize import fp8_quantize for name, mlp, layer in model.named_modules_with_parent(): if isinstance(mlp, GatedMLP): init_params = get_init_params(mlp) hidden_act = init_params["hidden_act"] if hidden_act not in ["silu", "gelu"]: logger.warning( f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping." 
) continue init_params["inner_layernorm"] = mlp.inner_layernorm is not None fused_layer = FusedGatedMLP(**init_params) fc_name = name + '.fc' layer_quant_cfg = model.config.get_quant_cfg(fc_name) layer_quant_algo = layer_quant_cfg.quant_algo if layer_quant_algo != QuantAlgo.FP8 and layer_quant_algo is not None: continue if isinstance(model.config.quantization.exclude_modules, list) \ and fc_name in model.config.quantization.exclude_modules: layer_quant_algo = None if layer_quant_algo == QuantAlgo.FP8: fused_layer = fp8_quantize(fused_layer, layer_quant_cfg) if isinstance(mlp.dtype, str): dtype = str_dtype_to_torch(mlp.dtype) else: dtype = trt_dtype_to_torch(mlp.dtype) gate_weight = numpy_to_torch(mlp.gate.weight.raw_value) fc_weight = numpy_to_torch(mlp.fc.weight.raw_value) assert gate_weight.dtype == fc_weight.dtype need_qdq = gate_weight.dtype == torch.float8_e4m3fn gate_weight = gate_weight.to(dtype) fc_weight = fc_weight.to(dtype) # dequantize if needed if need_qdq: gate_weight = gate_weight.to(dtype) * numpy_to_torch( mlp.gate.weights_scaling_factor.raw_value) fc_weight = fc_weight.to(dtype) * numpy_to_torch( mlp.fc.weights_scaling_factor.raw_value) # concat fused_weight = torch.cat([gate_weight, fc_weight], dim=0) fused_weight_scaling_factor = numpy_to_torch( max( mlp.gate.weights_scaling_factor.raw_value, mlp.fc.weights_scaling_factor.raw_value, )) # quantize if needed if need_qdq: fused_weight = (fused_weight / fused_weight_scaling_factor).to( torch.float8_e4m3fn) if gemm_swiglu_plugin_dtype == 'fp8': # gemm_swiglu_plugin needs (k, n) weights # but weights should still be k-major for fp8 fused_layer.fused_fc.weight = Parameter( shape=(fused_layer.fused_fc.in_features, fused_layer.fused_fc.out_features), dtype='fp8') fused_layer.fused_fc.weight.value = fused_weight.view( fused_layer.fused_fc.in_features, fused_layer.fused_fc.out_features) else: fused_layer.fused_fc.weight.value = fused_weight fused_layer.fused_fc.weights_scaling_factor.value = 
fused_weight_scaling_factor fused_layer.fused_fc.activation_scaling_factor.value = max( mlp.gate.activation_scaling_factor.raw_value, mlp.fc.activation_scaling_factor.raw_value, ) elif layer_quant_algo is None: fused_layer.fused_fc.weight.value = np.concatenate( [ mlp.gate.weight.raw_value, mlp.fc.weight.raw_value, ], axis=0, ) if mlp.bias: fused_layer.fused_fc.bias.value = np.concatenate( [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value], axis=0) else: raise ValueError(f'Unsupported quant algo: {layer_quant_algo}') fused_layer.proj = mlp.proj fused_layer.inner_layernorm = mlp.inner_layernorm _, mlp_name = name.rsplit('.', 1) setattr(layer, mlp_name, fused_layer) elif isinstance(mlp, Fp8RowwiseGatedMLP): init_params = get_init_params(mlp) hidden_act = init_params["hidden_act"] if hidden_act not in ["silu", "gelu"]: logger.warning( f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping." ) continue if mlp.clamp_val is not None: init_params["clamp_val"] = mlp.clamp_val.raw_value.tolist() fused_layer = Fp8RowwiseFusedGatedMLP(**init_params) fused_layer.fused_fc.weight.value = np.concatenate( [ mlp.gate.weight.raw_value, mlp.fc.weight.raw_value, ], axis=0, ) fused_layer.fused_fc.per_channel_scale.value = np.concatenate( [ mlp.gate.per_channel_scale.raw_value, mlp.fc.per_channel_scale.raw_value, ], axis=0, ) if mlp.bias: fused_layer.fused_fc.bias.value = np.concatenate( [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value], axis=0) fused_layer.proj = mlp.proj _, mlp_name = name.rsplit('.', 1) setattr(layer, mlp_name, fused_layer) return model def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel: '''Split all the models' Attention layer's QKV GEMM into 3 GEMMs layer.q layer.k, layer.v and return the changed model ''' from ..quantization.quantize import quantize for name, layer in model.named_modules(): if isinstance(layer, Attention) and not layer.cross_attention: assert layer.tp_size == 1, "please disable manual tp when enable 
auto parallel" if layer.qkv is None: continue qkv_params = get_init_params(layer.qkv, ColumnLinear) qkv_params["bias"] = qkv_params["bias"] is not None qkv_params["strict_dtype"] = qkv_params.get( "strict_dtype") is not None q = ColumnLinear( **{ **qkv_params, "out_features": layer.tp_size * layer.num_attention_heads * layer.attention_head_size, }) k = ColumnLinear( **{ **qkv_params, "out_features": layer.tp_size * layer.num_attention_kv_heads * layer.attention_head_size, }) v = ColumnLinear( **{ **qkv_params, "out_features": layer.tp_size * layer.num_attention_kv_heads * layer.attention_head_size, }) layer_quant_cfg = model.config.get_quant_cfg(name + '.qkv') q = quantize(q, layer_quant_cfg) k = quantize(k, layer_quant_cfg) v = quantize(v, layer_quant_cfg) out_features = q.out_features + k.out_features + v.out_features if isinstance(layer.qkv, ( WeightOnlyQuantLinear, WeightOnlyQuantRowLinear, WeightOnlyGroupwiseQuantLinear, WeightOnlyGroupwiseQuantRowLinear, )): out_dim = 1 else: out_dim = 0 if layer.qkv.weight.is_inited(): qkv_weight = layer.qkv.weight.raw_value weights = np.split(qkv_weight, [ qkv_weight.shape[out_dim] * q.out_features // out_features, qkv_weight.shape[out_dim] * (q.out_features + k.out_features) // out_features, ], axis=out_dim) for gemm, weight in zip([q, k, v], weights): gemm.weight.value = weight if layer.qkv.bias is not None and layer.qkv.bias.is_inited(): qkv_bias = layer.qkv.bias.raw_value biases = np.split(qkv_bias, [ qkv_bias.shape[out_dim] * q.out_features // out_features, qkv_bias.shape[out_dim] * (q.out_features + k.out_features) // out_features, ], axis=out_dim) for gemm, bias in zip([q, k, v], biases): gemm.bias.value = bias for name, parameter in layer.qkv._parameters.items(): if name not in ["weight", "bias"]: for gemm in [q, k, v]: setattr(gemm, name, parameter) layer.q = q layer.k = k layer.v = v layer.qkv = None return model def fuse_rg_lru(model: PretrainedModel) -> PretrainedModel: for name, rg_lru, parent in 
model.named_modules_with_parent(): if isinstance(rg_lru, RgLru): fused_layer = FusedRgLru(**get_init_params(rg_lru)) fused_layer.gate.weight.value = np.concatenate( [ rg_lru.input_gate.weight.raw_value, rg_lru.recurrent_gate.weight.raw_value, ], axis=-1, ) fused_layer.gate.bias.value = np.concatenate( [ rg_lru.input_gate.bias.raw_value, rg_lru.recurrent_gate.bias.raw_value, ], axis=-1, ) fused_layer.recurrent_param.value = rg_lru.recurrent_param.raw_value rg_lru_name = name.rsplit('.', 1)[-1] setattr(parent, rg_lru_name, fused_layer) return model def set_prompt_tuning(model: PretrainedModel) -> PretrainedModel: '''Replace the given models embedding layer with a PromptTuningEmbedding layer in-place, return the changed model Pre-conditions: vocab_embedding exists Post-conditions: isinstance(vocab_embedding, PromptTuningEmbedding) ''' for name, embedding, parent in model.named_modules_with_parent(): layer_name = name.rsplit('.', 1)[-1] if layer_name == "vocab_embedding" and isinstance(embedding, Embedding): ptuning_embedding = PromptTuningEmbedding( **get_init_params(embedding)) ptuning_embedding.weight.value = embedding.weight.raw_value parent.vocab_embedding = ptuning_embedding return model def add_lora(model: PretrainedModel, max_lora_rank: Optional[int]) -> PretrainedModel: ''' Add lora layers to the Attention/BertAttention/Linear/RowLinear/FusedGatedMLP layers to the given model, return the changed model ''' for name, layer in model.named_modules(): max_rank = max_lora_rank if isinstance(layer, (Attention, BertAttention)): if max_rank is None: max_rank = min( layer.hidden_size, layer.num_attention_heads * layer.attention_head_size, layer.num_attention_kv_heads * layer.attention_head_size) layer.qkv_lora = Lora( in_hidden_size=layer.hidden_size, out_hidden_sizes=[ layer.num_attention_heads * layer.attention_head_size, layer.num_attention_kv_heads * layer.attention_head_size, layer.num_attention_kv_heads * layer.attention_head_size ], max_low_rank=max_rank, ) if 
isinstance(layer, (Linear, RowLinear)): if max_rank is None: max_rank = min(layer.in_features, layer.out_features) layer.lora = Lora( in_hidden_size=layer.in_features, out_hidden_sizes=[layer.out_features], max_low_rank=max_rank, ) if isinstance(layer, (MLP, FusedGatedMLP)): if max_rank is None: max_rank = min(layer.hidden_size, layer.ffn_hidden_size // layer.tp_size) layer.lora = Lora( in_hidden_size=layer.hidden_size, out_hidden_sizes=[ layer.ffn_hidden_size // layer.tp_size, layer.ffn_hidden_size // layer.tp_size ], max_low_rank=max_rank, ) if isinstance(layer, MOE): if max_rank is None: max_rank = min(layer.hidden_size, layer.ffn_hidden_size // layer.tp_size) layer.max_low_rank = max_rank return model def to_ootb_moe(model: PretrainedModel) -> PretrainedModel: ''' Use OOTB MoE instead of MoE plugin, return the changed model ''' for name, layer, parent in model.named_modules_with_parent(): if isinstance(layer, MOE): layer_name = name.rsplit('.', 1)[-1] ootb_layer = layer.to(MoeOOTB, model.config.quantization) setattr(parent, layer_name, ootb_layer) return model def parallelize_embedding(model: PretrainedModel) -> PretrainedModel: for name, embedding, parent in model.named_modules_with_parent(): layer_name = name.rsplit('.', 1)[-1] if isinstance(embedding, Embedding) and embedding.tp_group is None: init_params = get_init_params(embedding) init_params["tp_group"] = model.config.mapping.tp_group init_params["tp_size"] = model.config.mapping.tp_size init_params["tp_rank"] = model.config.mapping.tp_rank init_params["sharding_dim"] = model.config.embedding_sharding_dim new_embedding = embedding.__class__(**init_params) setattr(parent, layer_name, new_embedding) return model def share_embedding(model: PretrainedModel) -> PretrainedModel: lm_head = None vocab_embedding = None for name, layer in model.named_modules(): layer_name = name.rsplit('.', 1)[-1] if layer_name == "lm_head": lm_head = layer if layer_name == "vocab_embedding": vocab_embedding = layer if lm_head is 
not None and vocab_embedding is not None: break if lm_head is not None and vocab_embedding is not None: lm_head.weight = vocab_embedding.weight if (hasattr(vocab_embedding, "per_token_scale") and vocab_embedding.per_token_scale is not None): lm_head.per_channel_scale = vocab_embedding.per_token_scale return model def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel: for name, layer in model.named_modules(): if isinstance(layer, Attention) and hasattr( layer.dense, 'activation_scaling_factor'): scale = [1.0] / layer.dense.activation_scaling_factor.raw_value layer.attention_output_orig_quant_scale = Parameter( value=scale.astype(np.float32)) return model def optimize_model( model: PretrainedModel, use_parallel_embedding: bool = False, share_embedding_table: bool = False, use_ootb_moe: bool = False, use_fused_mlp: bool = False, gemm_swiglu_plugin_dtype: Optional[str] = None, use_fused_rg_lru: bool = False, use_unfused_qkv_gemm: bool = False, use_prompt_tuning: bool = False, use_lora: bool = False, max_lora_rank: Optional[int] = None, use_fp8_context_fmha: bool = False, ) -> PretrainedModel: """ Run optimization passes on model. There are dependencies between some passes, so we always run passes in the order of arguments to guarantee the execution order. 
""" # before weight loading if use_parallel_embedding: model = parallelize_embedding(model) if share_embedding_table: # if share_embedding_table is enabled, only one copy of the embedding table is store in converted ckpt # this pass is required to make lm_head.weight and vocab_embedding.weight point to the same tensor # however even if share_embedding_table is not enabled, trt would still only keep one copy of the table if the weights are identical model = share_embedding(model) # After weight loading if use_ootb_moe: model = to_ootb_moe(model) if use_fused_mlp: model = fuse_gate_mlp(model, gemm_swiglu_plugin_dtype) if use_fused_rg_lru: model = fuse_rg_lru(model) if use_unfused_qkv_gemm: model = unfuse_qkv_gemm(model) if use_prompt_tuning: model = set_prompt_tuning(model) if use_lora: model = add_lora(model, max_lora_rank) if use_fp8_context_fmha: model = set_fp8_context_fhma(model) return model def preprocess_perlayer_weights(weights, model_config, quant_algo, from_pruned=False): exclude_modules = model_config.quantization.exclude_modules # INT4_AWQ if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ: preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm if quant_algo == QuantAlgo.W4A8_AWQ: activation_type = torch.float8_e4m3fn elif quant_algo == QuantAlgo.W4A16_AWQ: activation_type = torch.float16 for name, param in weights.items(): if from_pruned and param.numel() == 0: continue if name.endswith('weight') and param.dtype == torch.int8: dtype = torch.float16 if model_config.dtype == "bfloat16": dtype = torch.bfloat16 weights[name] = preprocessor(param.T.contiguous(), torch.quint4x2, activation_type).view(dtype) if name.endswith('weights_scaling_factor' ) and param.shape[0] > param.shape[1]: # TODO: refine on supporting ModelOpt HF-AWQ weights[name] = param.T.contiguous().to( str_dtype_to_torch(model_config.dtype)) if name.endswith('prequant_scaling_factor'): weights[name] = param.reshape(1, -1) if model_config.mapping.tp_rank > 0: 
if name.endswith('attention.dense.bias') or name.endswith( 'mlp.proj.bias'): weights[name] = torch.zeros_like(param) if quant_algo == QuantAlgo.W4A8_AWQ: for name in list(weights): if name.endswith('weights_scaling_factor'): activation_scaling_factor = weights.pop( name.replace('weights_scaling_factor', 'activation_scaling_factor')) weights_scaling_factor_2 = weights.pop( name.replace('weights_scaling_factor', 'weights_scaling_factor_2')) weights[name] /= weights_scaling_factor_2 weights[name.replace( 'weights_scaling_factor', 'prequant_scaling_factor')] /= activation_scaling_factor weights[name.replace( 'weights_scaling_factor', 'alpha' )] = activation_scaling_factor * weights_scaling_factor_2 # FP8 elif quant_algo == QuantAlgo.FP8: for name, param in weights.items(): if name.endswith('weight') and param.dtype == torch.int8: weights[name] = param.view(torch.float8_e4m3fn) # lm_head is not quantized to FP8 if "lm_head.weight" in weights: assert weights['lm_head.weight'].dtype == str_dtype_to_torch( model_config.dtype) weights.pop('lm_head.weights_scaling_factor', None) weights.pop('lm_head.activation_scaling_factor', None) elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN: for name, param in weights.items(): if name.endswith('weight') and param.dtype == torch.int8: weights[name] = param.view(torch.float8_e4m3fn) # lm_head is not quantized to FP8 if "lm_head.weight" in weights: assert weights['lm_head.weight'].dtype == str_dtype_to_torch( model_config.dtype) weights.pop('lm_head.weights_scaling_factor', None) weights.pop('lm_head.activation_scaling_factor', None) elif quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16]: weights = weight_only_quantize_dict(weights=weights, quant_algo=quant_algo, exclude_modules=exclude_modules, plugin=True) def preprocess_weights(weights: Dict[str, torch.Tensor], model_config: PretrainedConfig, from_pruned=False) -> None: """This function in-place modifies weights and model_config, making them compatible with each other. 
Note: Typically, it should be called before model creation and weight loading. For example, preprocess_weights(weights, model_config) model = XXXForCausalLM(model_config) model.load(weights) """ quant_config = model_config.quantization quant_algo = quant_config.quant_algo pattern_info = ['fc', 'gate', 'proj', 'qkv', 'dense'] per_layer_weights = {} for name, param in weights.items(): in_mode = False for info in pattern_info: pattern = rf'(.*?{info}.*?)' pattern_match = re.match(pattern, name) if pattern_match: base_name = pattern_match.group(1) if base_name not in per_layer_weights.keys(): per_layer_weights[base_name] = {} per_layer_weights[base_name][name] = param in_mode = True break if not in_mode: # [lm_head.weight, ln_f.weight, vocab_embedding.weight] base_name = name.rsplit('.', 1)[0] if base_name not in per_layer_weights.keys(): per_layer_weights[base_name] = {} per_layer_weights[base_name][name] = param new_weights = {} for base_name, layer_weights in per_layer_weights.items(): if quant_algo != QuantAlgo.MIXED_PRECISION: layer_quant_algo = quant_algo else: if base_name not in quant_config.quantized_layers.keys(): new_weights.update(layer_weights) continue layer_quant_algo = quant_config.quantized_layers[ base_name].quant_algo preprocess_perlayer_weights(layer_weights, model_config, layer_quant_algo, from_pruned) new_weights.update(layer_weights) weights = new_weights for name, param in weights.items(): if model_config.architecture == 'GPTJForCausalLM': if model_config.mapping.tp_rank > 0: if 'attention.dense.bias' in name or 'mlp.proj.bias' in name: weights[name] = torch.zeros_like(param) # For share_embedding_table check_share_embedding(weights, model_config) return weights def check_share_embedding(weights: Dict[str, torch.Tensor], model_config: PretrainedConfig): if model_config.share_embedding_table: if "lm_head.weight" in weights: if weights["lm_head.weight"] is None: weights.pop("lm_head.weight") if "lm_head.weight" in weights and 
"transformer.vocab_embedding.weight" in weights: if (weights["lm_head.weight"] - weights["transformer.vocab_embedding.weight"]).any(): logger.warning( "lm_head.weight and transformer.vocab_embedding.weight are not identical, " "share_embedding_table cannot be enabled; setting share_embedding_table=False." ) model_config.share_embedding_table = False else: weights.pop("lm_head.weight") def get_kv_cache_type_from_legacy(use_cache: bool, paged_kv_cache: bool) -> KVCacheType: if use_cache: if paged_kv_cache: return KVCacheType.PAGED else: return KVCacheType.CONTINUOUS else: return KVCacheType.DISABLED def save_config(config: PretrainedConfig, *, output_dir: str, log: bool) -> None: config_path = Path(output_dir) / "config.json" if log: logger.debug(f"Saving TensorRT-LLM configuration to {config_path}") config_path.parent.mkdir(exist_ok=True, parents=True) config_path.write_text(json.dumps(config.to_dict(), indent=4)) def save_checkpoint(*, output_dir: str, weights: dict, rank: int) -> None: """ Checkpoint saver for weight loader.""" safetensors.torch.save_file( weights, os.path.join(output_dir, f'rank{rank}.safetensors'))