import argparse
import copy
import dataclasses
import fnmatch
import json
import os
import re
from enum import IntFlag, auto
from functools import cached_property
from pathlib import Path
from typing import (TYPE_CHECKING, Callable, Dict, Generator, List, Optional,
                    Union)
import numpy as np
import safetensors
import torch
from .._common import default_net
from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch,
                      release_gc, str_dtype_to_torch, str_dtype_to_trt,
                      trt_dtype_to_torch)
from ..bindings import KVCacheType
from ..bindings.executor import RuntimeDefaults
from ..functional import (PositionEmbeddingType, Tensor, allgather, constant,
                          cp_split_plugin, gather_last_token_logits,
                          index_select, tanh, view)
from ..layers import (MLP, AttentionParams, Embedding, FusedGatedMLP,
                      FusedRgLru, GatedMLP, KeyValueCacheParams, LoraParams,
                      PromptTuningEmbedding, RgLru)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Dora, Lora
from ..layers.moe import MOE, MoeOOTB
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..parameter import Parameter
from ..plugin import init_all_reduce_helper
from ..quantization import QuantMode
from ..quantization.functional import preprocess_weights_for_mixed_gemm
from ..quantization.layers import (FP8Linear, Fp8RowwiseFusedGatedMLP,
                                   Fp8RowwiseGatedMLP,
                                   WeightOnlyGroupwiseQuantLinear,
                                   WeightOnlyGroupwiseQuantRowLinear,
                                   WeightOnlyQuantLinear,
                                   WeightOnlyQuantRowLinear)
from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST,
                                 W8A8_SQ_PLUGIN_LIST, QuantAlgo)
from ..quantization.utils import fp4_utils
from ..top_model_mixin import TopModelMixin
from .convert_utils import weight_only_quantize_dict
from .generation_mixin import GenerationMixin
@dataclasses.dataclass(kw_only=True, frozen=True)
class Gemma2ConfigGroup:
    query_pre_attn_scalar: int
    final_logit_softcapping: Optional[float]
    attn_logit_softcapping: Optional[float]
    @classmethod
    def keys(cls):
        return {f.name for f in dataclasses.fields(cls)}
@dataclasses.dataclass(kw_only=True, frozen=True)
class Gemma3ConfigGroup:
    query_pre_attn_scalar: float
    final_logit_softcapping: Optional[float]
    sliding_window_pattern: int
    rope_local_base_freq: int
    sliding_window: int
    @classmethod
    def keys(cls):
        return {f.name for f in dataclasses.fields(cls)}
if TYPE_CHECKING:
    from typing import Type, TypeVar
    from typing_extensions import Self
    ConfigGroups = Union[Gemma2ConfigGroup, Gemma3ConfigGroup]
    """Groupings of config where, if one of said properties exists, we assume all of the properties exist (even if they are `None`)"""
    CG = TypeVar("CG", bound=ConfigGroups)
    RuntimeDefaultsIn = Optional[Union[RuntimeDefaults, dict]]
class SpeculativeDecodingMode(IntFlag):
    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/runtime/speculativeDecodingMode.h
    NONE = auto()
    DRAFT_TOKENS_EXTERNAL = auto()
    MEDUSA = auto()
    LOOKAHEAD_DECODING = auto()
    EXPLICIT_DRAFT_TOKENS = auto()
    EAGLE = auto()
    NGRAM = auto()
    USER_PROVIDED = auto()
    AUTO = auto()
    @staticmethod
    def from_arguments(args: argparse.Namespace):
        if args.speculative_decoding_mode is None:
            return SpeculativeDecodingMode.NONE
        elif args.speculative_decoding_mode == "draft_tokens_external":
            return SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
        elif args.speculative_decoding_mode == "medusa":
            return SpeculativeDecodingMode.MEDUSA
        elif args.speculative_decoding_mode == "lookahead_decoding":
            return SpeculativeDecodingMode.LOOKAHEAD_DECODING
        elif args.speculative_decoding_mode == "explicit_draft_tokens":
            return SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS
        elif args.speculative_decoding_mode == "eagle":
            return SpeculativeDecodingMode.EAGLE
        elif args.speculative_decoding_mode == "ngram":
            return SpeculativeDecodingMode.NGRAM
        elif args.speculative_decoding_mode == "user_provided":
            return SpeculativeDecodingMode.USER_PROVIDED
        elif args.speculative_decoding_mode == "auto":
            return SpeculativeDecodingMode.AUTO
        else:
            assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode 
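# Example (illustrative sketch): SpeculativeDecodingMode is an IntFlag, so modes can be
# combined and tested with bitwise operators, and from_arguments() maps the CLI string
# (e.g. from trtllm-build's --speculative_decoding_mode) to a flag value. The argparse
# namespace below is a hypothetical stand-in for the parsed build arguments.
#
#     import argparse
#     args = argparse.Namespace(speculative_decoding_mode="medusa")
#     mode = SpeculativeDecodingMode.from_arguments(args)
#     assert mode == SpeculativeDecodingMode.MEDUSA
#     assert mode & (SpeculativeDecodingMode.MEDUSA | SpeculativeDecodingMode.EAGLE)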
 
@dataclasses.dataclass
class QuantConfig:
    """
    Serializable quantization configuration class, part of the PretrainedConfig.
    Args:
        quant_algo (tensorrt_llm.quantization.mode.QuantAlgo, optional): Quantization algorithm. Defaults to None.
        kv_cache_quant_algo (tensorrt_llm.quantization.mode.QuantAlgo, optional): KV cache quantization algorithm. Defaults to None.
        group_size (int): The group size for group-wise quantization. Defaults to 128.
        smoothquant_val (float): The smoothing parameter alpha used in smooth quant. Defaults to 0.5.
        clamp_val (List[float], optional): The clamp values used in FP8 rowwise quantization. Defaults to None.
        use_meta_recipe (bool): Whether to use Meta's recipe for FP8 rowwise quantization. Defaults to False.
        has_zero_point (bool): Whether to use zero point for quantization. Defaults to False.
        pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False.
        exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None.
    """
    quant_algo: Optional[QuantAlgo] = None
    kv_cache_quant_algo: Optional[QuantAlgo] = None
    group_size: int = 128
    smoothquant_val: float = 0.5
    clamp_val: Optional[List[float]] = None
    use_meta_recipe: bool = False
    has_zero_point: bool = False
    pre_quant_scale: bool = False
    exclude_modules: Optional[List[str]] = None
    @cached_property
    def quant_mode(self) -> QuantModeWrapper:
        quant_mode_list = [
            QuantMode.from_quant_algo(
                self.quant_algo,
                self.kv_cache_quant_algo,
            )
        ]
        return QuantModeWrapper(quant_mode_list)
    @cached_property
    def layer_quant_mode(self) -> QuantMode:
        return QuantMode.from_quant_algo(
            self.quant_algo,
            self.kv_cache_quant_algo,
        )
    @property
    def _use_plugin_sq(self):
        return self.quant_algo in W8A8_SQ_PLUGIN_LIST
    @property
    def _requires_calibration(self):
        return self.quant_algo in (set(QUANT_ALGO_LIST) - {
            QuantAlgo.W8A16, QuantAlgo.W4A16,
            QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN
        }) or self.kv_cache_quant_algo in KV_CACHE_QUANT_ALGO_LIST
    @property
    def _requires_modelopt_quantization(self):
        if self.quant_algo in [
                QuantAlgo.NVFP4, QuantAlgo.FP8, QuantAlgo.W4A16_AWQ,
                QuantAlgo.W4A8_AWQ, QuantAlgo.W8A8_SQ_PER_CHANNEL,
                QuantAlgo.MIXED_PRECISION
        ]:
            return True
        elif self.quant_algo is None and self.kv_cache_quant_algo == QuantAlgo.FP8:
            return True
        else:
            return False
    def _get_quant_cfg(self, module_name=None):
        if self.exclude_modules is not None:
            for exclude_module in self.exclude_modules:
                if exclude_module == module_name or (
                        exclude_module.endswith('*')
                        and module_name.startswith(exclude_module[:-1])):
                    return LayerQuantConfig(quant_algo=None,
                                            quantized_layers={})
        return self
    def _get_modelopt_qformat(self):
        algo_to_modelopt_map = {
            QuantAlgo.W8A16: "int8_wo",
            QuantAlgo.W4A16: "int4_wo",
            QuantAlgo.NVFP4: "nvfp4",
            QuantAlgo.FP8: "fp8",
            QuantAlgo.W4A16_AWQ: "int4_awq",
            QuantAlgo.W4A8_AWQ: "w4a8_awq",
            QuantAlgo.W8A8_SQ_PER_CHANNEL: "int8_sq",
        }
        assert self.quant_algo != QuantAlgo.MIXED_PRECISION, "Mixed precision is not supported by QuantConfig; use LayerQuantConfig instead"
        if self.quant_algo is not None:
            assert self.quant_algo in algo_to_modelopt_map, f"Modelopt quantization is not used for algorithm {self.quant_algo}; this method should not be called"
            return algo_to_modelopt_map[self.quant_algo]
        else:
            return 'full_prec'
    def _get_modelopt_kv_cache_dtype(self):
        algo_to_modelopt_map = {
            QuantAlgo.FP8: 'fp8',
            QuantAlgo.INT8: 'int8',
        }
        if self.kv_cache_quant_algo is not None:
            assert self.kv_cache_quant_algo in algo_to_modelopt_map, f"Modelopt quantization is not used for KV cache algorithm {self.kv_cache_quant_algo}; this method should not be called"
            return algo_to_modelopt_map[self.kv_cache_quant_algo]
        else:
            return None
    def is_module_excluded_from_quantization(self, name: str) -> bool:
        """Check if the module is excluded from quantization.
        Args:
            name (str): The name of the module.
        Returns:
            bool: True if the module is excluded from quantization, False otherwise.
        """
        if self.exclude_modules is not None:
            for exclude_module in self.exclude_modules:
                if fnmatch.fnmatchcase(name, exclude_module):
                    return True
        return False 
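    # Example (sketch): exclude_modules entries are fnmatch-style patterns matched
    # case-sensitively against full module names. With hypothetical names:
    #
    #     cfg = QuantConfig(quant_algo=QuantAlgo.FP8,
    #                       exclude_modules=['lm_head', 'transformer.layers.0.*'])
    #     cfg.is_module_excluded_from_quantization('lm_head')                      # True
    #     cfg.is_module_excluded_from_quantization('transformer.layers.0.mlp.fc')  # True
    #     cfg.is_module_excluded_from_quantization('transformer.layers.1.mlp.fc')  # False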
    @classmethod
    def from_dict(cls, config: dict) -> 'QuantConfig':
        """Create a QuantConfig instance from a dict.
        Args:
            config (dict): The dict used to create QuantConfig.
        Returns:
            tensorrt_llm.models.modeling_utils.QuantConfig: The QuantConfig created from dict.
        """
        obj = cls(**config)
        return obj 
    def to_dict(self) -> dict:
        """Dump a QuantConfig instance to a dict.
        Returns:
            dict: The dict dumped from QuantConfig.
        """
        return dataclasses.asdict(self) 
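# Example (illustrative sketch): a typical QuantConfig round trip with hypothetical values.
# quant_mode / layer_quant_mode are derived lazily from the chosen algorithms.
#
#     cfg = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ,
#                       kv_cache_quant_algo=QuantAlgo.INT8,
#                       group_size=128)
#     assert cfg.layer_quant_mode.has_int8_kv_cache()
#     restored = QuantConfig.from_dict(cfg.to_dict())  # serialize / deserialize
#     assert restored.quant_algo == QuantAlgo.W4A16_AWQ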
 
@dataclasses.dataclass
class LayerQuantConfig(QuantConfig):
    quant_algo: Optional[QuantConfig] = None
    kv_cache_quant_algo: Optional[QuantConfig] = None
    quantized_layers: Optional[Dict[str, QuantConfig]] = None
    def __init__(self,
                 *,
                 quant_algo: Optional[QuantConfig] = None,
                 kv_cache_quant_algo: Optional[QuantConfig] = None,
                 quantized_layers: Optional[Dict[str, QuantConfig]] = None,
                 **kwargs):
        self.quant_algo = quant_algo
        self.quantized_layers = quantized_layers
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.auto_quant_mode = {}
        for name, layer_config in self.quantized_layers.items():
            self.auto_quant_mode.update({
                name:
                QuantMode.from_quant_algo(
                    layer_config.quant_algo,
                    self.kv_cache_quant_algo,
                )
            })
        for key in kwargs:
            logger.warning(
                f"Unrecognized parameter '{key}' with value '{kwargs[key]}'"
            )
    @cached_property
    def quant_mode(self):
        quant_mode_list = list(set(self.auto_quant_mode.values()))
        return QuantModeWrapper(quant_mode_list)
    #@lru_cache(maxsize=None)
    def layer_quant_mode(self, layer_name) -> QuantMode:
        for name, quant_mode in self.auto_quant_mode.items():
            if fnmatch.fnmatch(layer_name, name):
                return quant_mode
        return QuantMode(0)
    @cached_property
    def auto_quant_list(self):
        quant_list = []
        for _, layer_config in self.quantized_layers.items():
            quant_list.append(layer_config.quant_algo)
        return list(set(quant_list))
    @classmethod
    def from_dict(cls, config: dict):
        quantized_layers = config.pop('quantized_layers', {})
        quantized_layers_dict = {
            layer_name: QuantConfig(**layer_config)
            for layer_name, layer_config in quantized_layers.items()
        }
        obj = cls(quantized_layers=quantized_layers_dict, **config)
        return obj
    #@lru_cache(maxsize=None)
    def _get_quant_cfg(self, module_name):
        quant_res = QuantConfig()
        for name, quant_cfg in self.quantized_layers.items():
            if fnmatch.fnmatch(module_name, name):
                quant_res = quant_cfg
                break
        return quant_res
    def _get_modelopt_qformat(self):
        algo_to_modelopt_map = {
            QuantAlgo.NVFP4: "nvfp4",
            QuantAlgo.FP8: "fp8",
            QuantAlgo.W4A16_AWQ: "int4_awq",
            QuantAlgo.W4A8_AWQ: "w4a8_awq",
            QuantAlgo.W8A8_SQ_PER_CHANNEL: "int8_sq",
        }
        assert self.quant_algo == QuantAlgo.MIXED_PRECISION, "LayerQuantConfig only supports mixed precision quantization"
        autoq_format = ','.join(
            [algo_to_modelopt_map[item] for item in self.auto_quant_list])
        return autoq_format
    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        output.pop('auto_quant_mode', None)
        output.pop('quant_mode', None)
        for name, per_layer_config in output['quantized_layers'].items():
            per_layer_config = per_layer_config.to_dict()
            output['quantized_layers'][name] = per_layer_config
        return output
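# Example (illustrative sketch): LayerQuantConfig describes per-layer (mixed precision /
# "auto quant") quantization. Keys of 'quantized_layers' are fnmatch patterns; the layer
# names below are hypothetical.
#
#     layer_cfg = LayerQuantConfig.from_dict({
#         'quant_algo': QuantAlgo.MIXED_PRECISION,
#         'quantized_layers': {
#             'transformer.layers.0.attention.qkv': {'quant_algo': QuantAlgo.FP8},
#             'transformer.layers.0.mlp.fc': {'quant_algo': QuantAlgo.W4A16_AWQ},
#         },
#     })
#     layer_cfg.layer_quant_mode('transformer.layers.0.mlp.fc')  # QuantMode for W4A16_AWQ
#     layer_cfg._get_modelopt_qformat()                          # e.g. "fp8,int4_awq"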
class PretrainedConfig:
    def __init__(self,
                 *,
                 architecture: str,
                 dtype: str,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 vocab_size: Optional[int] = None,
                 hidden_act: str = 'gelu',
                 logits_dtype: str = 'float32',
                 norm_epsilon: float = 1e-5,
                 position_embedding_type: Union[
                     PositionEmbeddingType,
                     str] = PositionEmbeddingType.learned_absolute,
                 max_position_embeddings: Optional[int] = None,
                 rotary_embedding_dim: Optional[int] = None,
                 num_key_value_heads: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 mapping: Optional[Union[Mapping, dict]] = None,
                 quantization: Optional[Union[QuantConfig, dict]] = None,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 head_size: Optional[int] = None,
                 qk_layernorm: bool = False,
                 runtime_defaults: "RuntimeDefaultsIn" = None,
                 **kwargs):
        self.architecture = architecture
        self.dtype = dtype
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.logits_dtype = logits_dtype
        self.norm_epsilon = norm_epsilon
        self.runtime_defaults = self.create_runtime_defaults(runtime_defaults)
        if isinstance(position_embedding_type, str):
            position_embedding_type = PositionEmbeddingType.from_string(
                position_embedding_type)
        assert isinstance(position_embedding_type, PositionEmbeddingType)
        self.position_embedding_type = position_embedding_type
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        if intermediate_size is None:
            intermediate_size = hidden_size * 4
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        if mapping is None:
            mapping = Mapping()
        elif isinstance(mapping, dict):
            mapping = Mapping.from_dict(mapping)
        assert isinstance(mapping, Mapping)
        self.mapping = mapping
        if quantization is None:
            quantization = QuantConfig()
        elif isinstance(quantization, dict):
            quantization = QuantConfig.from_dict(quantization)
        assert isinstance(quantization, QuantConfig)
        self.quantization = quantization
        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim
        if head_size is None:
            head_size = hidden_size // num_attention_heads
        self.head_size = head_size
        self.qk_layernorm = qk_layernorm
        if rotary_embedding_dim is None:
            rotary_embedding_percentage = kwargs.get('rotary_pct', 1.0)
            rotary_embedding_dim = kwargs.get(
                'rotary_dim', int(head_size * rotary_embedding_percentage))
        self.rotary_embedding_dim = rotary_embedding_dim
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
                logger.warning(
                    f"Implicitly setting {self.__class__.__name__}.{key} = {value}"
                )
            except AttributeError as err:
                raise err
    @staticmethod
    def create_runtime_defaults(
            defaults: "RuntimeDefaultsIn" = None) -> Optional[RuntimeDefaults]:
        if isinstance(defaults, dict):
            return RuntimeDefaults(**defaults)
        return defaults 
    @property
    def kv_dtype(self):
        # TODO: need to align the kv dtype
        # now assume the kv cache is for all layers
        if self.quant_mode.has_int8_kv_cache():
            return 'int8'
        elif self.quant_mode.has_fp8_kv_cache():
            return 'fp8'
        elif self.quant_mode.has_fp4_kv_cache():
            return 'fp4'
        else:
            return self.dtype
    def set_if_not_exist(self, key, value):
        if not hasattr(self, key):
            setattr(self, key, value) 
    @classmethod
    def from_dict(cls, config: dict):
        # Maybe we need AutoConfig for this
        from . import MODEL_MAP
        model_cls = MODEL_MAP[config['architecture']]
        config_cls = getattr(model_cls, 'config_class', cls)
        return config_cls(**config) 
    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        output['position_embedding_type'] = str(self.position_embedding_type)
        output['mapping'] = self.mapping.to_dict()
        output['mapping'].pop('rank')
        output['quantization'] = self.quantization.to_dict()
        return output 
    @classmethod
    def from_json_file(cls, config_file: str):
        with open(config_file) as f:
            config = json.load(f)
        obj = cls.from_dict(config)
        if obj.quantization.quant_algo == QuantAlgo.MIXED_PRECISION:
            try:
                layer_config_path = str(config_file).replace(
                    'config.json', 'quant_cfg.json')
                obj.to_layer_quant_config(layer_config_path)
            except Exception as e:
                raise RuntimeError(
                    f"Encountered error '{e}' while reading quantization config '{layer_config_path}'"
                )
        return obj 
    @classmethod
    def from_checkpoint(cls, ckpt_dir: str):
        return cls.from_json_file(os.path.join(ckpt_dir, 'config.json')) 
    def to_json_file(self, config_file: str):
        with open(config_file, 'w') as f:
            json.dump(self.to_dict(), f, indent=4) 
    def to_layer_quant_config(self, config_file: str):
        with open(config_file) as f:
            config = json.load(f)
        if self.architecture == "MixtralForCausalLM":
            for layer_name in list(config["quantized_layers"].keys()):
                quant_cfg = config["quantized_layers"][layer_name]
                if "mlp.fc" in layer_name or "mlp.proj" in layer_name:
                    moe_name, _ = layer_name.rsplit('.', 1)
                    if moe_name not in config["quantized_layers"]:
                        config["quantized_layers"][moe_name] = quant_cfg
                    else:
                        assert quant_cfg == config["quantized_layers"][
                            moe_name], "MoE module needs to use the same quantization format for non-router sub-modules"
        self.quantization = LayerQuantConfig.from_dict(config) 
    @property
    def quant_mode(self):
        return self.quantization.quant_mode
    @property
    def quant_algo(self):
        return self.quantization.quant_algo
    def _get_quant_cfg(self, module_name: str):
        return self.quantization._get_quant_cfg(module_name)
    def set_rank(self, rank: int):
        self.mapping.rank = rank 
    def get_config_group(self, group_cls: "Type[CG]") -> "CG":
        cfg = {k: v for k, v in self.to_dict().items() if k in group_cls.keys()}
        return group_cls(**cfg) 
    def has_config_group(self, group_cls: "Type[CG]") -> "bool":
        return all(hasattr(self, key) for key in group_cls.keys()) 
    def for_each_rank(self) -> "Generator[Self, None, None]":
        for rank in range(self.mapping.world_size):
            config_copy = copy.deepcopy(self)
            config_copy.set_rank(rank)
            yield config_copy 
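# Example (illustrative sketch): building a minimal PretrainedConfig. Unknown keyword
# arguments are attached to the instance with a warning; mapping/quantization accept
# either objects or plain dicts. All values below are hypothetical.
#
#     config = PretrainedConfig(architecture='LlamaForCausalLM',
#                               dtype='float16',
#                               hidden_size=4096,
#                               num_hidden_layers=32,
#                               num_attention_heads=32,
#                               vocab_size=32000,
#                               mapping={'world_size': 2, 'tp_size': 2},
#                               quantization={'quant_algo': QuantAlgo.FP8})
#     config.has_config_group(Gemma2ConfigGroup)  # False for this hypothetical config
#     for rank_config in config.for_each_rank():  # one deep copy per rank
#         print(rank_config.mapping.rank)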
 
class DecoderLayerList(ModuleList):
    def __init__(self, cls, config):
        self.num_hidden_layers = config.num_hidden_layers
        self.layer_list = config.mapping.pp_layers(config.num_hidden_layers)
        self.quant_mode = config.quant_mode
        super().__init__([cls(config, idx) for idx in self.layer_list])
    def forward(self,
                hidden_states,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                mrope_params=None,
                position_ids=None,
                lora_params=None,
                spec_decoding_params=None,
                vision_token_mask=None):
        kv_cache_params.fill_none_tensor_list(len(self.layer_list))
        if use_cache:
            presents = []
        for layer_idx, (layer, past) in enumerate(
                zip(self, kv_cache_params.past_key_value)):
            lora_layer_params = None
            if lora_params is not None and lora_params.lora_ranks is not None:
                lora_layer_params = lora_params.get_layer_params(layer_idx)
            kwargs = {}
            if position_ids is not None:
                kwargs['position_ids'] = position_ids
            if vision_token_mask is not None:
                kwargs['vision_token_mask'] = vision_token_mask
            if lora_layer_params is not None:
                kwargs['lora_layer_params'] = lora_layer_params
            if spec_decoding_params is not None:
                kwargs['spec_decoding_params'] = spec_decoding_params
            if mrope_params is not None:
                kwargs['mrope_params'] = mrope_params
            if default_net().plugin_config.reduce_fusion:
                if layer_idx + self.layer_list[0] < self.layer_list[-1]:
                    qkv_activation_scaling_factor = None
                    if default_net().plugin_config.user_buffer:
                        qkv_linear = self[layer_idx + 1].attention.qkv
                        if self.quant_mode.has_fp8_qdq():
                            qkv_activation_scaling_factor = constant(
                                qkv_linear.activation_scaling_factor.raw_value.
                                copy())
                        elif self.quant_mode.has_nvfp4():
                            qkv_activation_scaling_factor = constant(
                                qkv_linear.activation_global_scaling_factor.
                                raw_value.copy())
                    kwargs['next_layer_input_layernorm_args'] = (
                        self[layer_idx + 1].input_layernorm.weight.value,
                        self[layer_idx + 1].input_layernorm.eps,
                        qkv_activation_scaling_factor)
                else:
                    kwargs['next_layer_input_layernorm_args'] = None
            elif default_net().plugin_config.norm_quant_fusion:
                if layer_idx < self.layer_list[-1] - self.layer_list[0]:
                    try:
                        activation_scaling_factor = constant(
                            self[layer_idx + 1].attention.qkv.
                            activation_global_scaling_factor.raw_value.copy())
                    except:
                        activation_scaling_factor = None
                    kwargs['next_layer_input_layernorm_args'] = (
                        self[layer_idx + 1].input_layernorm.weight.value,
                        self[layer_idx + 1].input_layernorm.eps,
                        activation_scaling_factor)
                else:
                    kwargs['next_layer_input_layernorm_args'] = None
            hidden_states = layer(
                hidden_states,
                use_cache=use_cache,
                attention_mask=attention_mask,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_attention_window_sizes=kv_cache_params.
                    host_max_attention_window_sizes,
                    host_sink_token_length=kv_cache_params.
                    host_sink_token_length,
                    kv_cache_block_offsets=kv_cache_params.
                    kv_cache_block_offsets,
                    host_kv_cache_block_offsets=kv_cache_params.
                    host_kv_cache_block_offsets,
                    host_kv_cache_pool_pointers=kv_cache_params.
                    host_kv_cache_pool_pointers,
                    host_kv_cache_pool_mapping=kv_cache_params.
                    host_kv_cache_pool_mapping,
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                **kwargs)
            if use_cache:
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]
        if use_cache:
            return hidden_states, presents
        return hidden_states
class PostInitCaller(type):
    def __call__(cls, *args, **kwargs):
        obj = type.__call__(cls, *args, **kwargs)
        obj.__post_init__()
        return obj
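# Example (illustrative sketch): PostInitCaller is a metaclass that calls __post_init__
# right after __init__ returns; PretrainedModel relies on this to apply quantization and
# optimization passes immediately after construction. A standalone illustration:
#
#     class Demo(metaclass=PostInitCaller):
#         def __init__(self):
#             self.stages = ['init']
#         def __post_init__(self):
#             self.stages.append('post_init')
#
#     assert Demo().stages == ['init', 'post_init']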
class PretrainedModel(Module,
                      GenerationMixin,
                      TopModelMixin,
                      metaclass=PostInitCaller):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        init_all_reduce_helper()
        self.config = config
    def __post_init__(self):
        from ..quantization.quantize import quantize
        quantize(self, self.config.quantization)
        # Currently, use_parallel_embedding must be enabled before weight loading;
        # otherwise, the model will be inconsistent with the weights loaded from checkpoint.
        optimize_model(
            self, use_parallel_embedding=self.config.use_parallel_embedding)
    def release(self):
        release_gc() 
    def __del__(self):
        self.release()
    def check_config(self, config):
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        ) 
    @classmethod
    def from_config(cls, config: PretrainedConfig):
        return cls(config) 
    @classmethod
    def from_checkpoint(
        cls,
        ckpt_dir: str,
        rank: Optional[int] = None,
        config: Optional[PretrainedConfig] = None,
        *,
        preprocess_weights_hook: Optional[Callable[[Dict[str, Tensor]],
                                                   Dict[str, Tensor]]] = None):
        if config is None:
            config = PretrainedConfig.from_json_file(
                os.path.join(ckpt_dir, 'config.json'))
        if rank is not None:
            config.set_rank(rank)
        rank = config.mapping.rank
        if config.mapping.auto_parallel:
            rank = 0
        elif config.mapping.cp_size > 1:
            # tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
            tp_size = config.mapping.tp_size
            cp_size = config.mapping.cp_size
            rank = rank % tp_size + rank // (tp_size * cp_size) * tp_size
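            # Worked example of the remapping above: with tp_size=2, cp_size=2, pp_size=2
            # (world_size=8), ranks 0..7 map to 0, 1, 0, 1, 2, 3, 2, 3, so both cp ranks
            # of each (tp, pp) slot read the same rank{N}.safetensors file.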
        weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')
        assert os.path.isfile(weights_path)
        weights = safetensors.torch.load_file(weights_path)
        is_checkpoint_pruned = getattr(config, 'is_pruned', False)
        if preprocess_weights_hook is not None:
            weights = preprocess_weights_hook(weights)
        weights = preprocess_weights(weights,
                                     config,
                                     from_pruned=is_checkpoint_pruned)
        model = cls(config)
        model.load(weights, from_pruned=is_checkpoint_pruned)
        return model 
    def load(self, weights, from_pruned=False):
        required_names = set()
        for name, param in self.named_parameters():
            if param.is_inited():
                continue
            if name not in weights:
                # Exemption for embedding sharing
                if name.endswith('lm_head.weight') and any(
                        k.endswith('vocab_embedding.weight')
                        for k in weights.keys()):
                    continue
                if name.endswith('lm_head.per_channel_scale') and any(
                        k.endswith('vocab_embedding.per_channel_scale')
                        for k in weights.keys()):
                    continue
            required_names.add(name)
        provided_names = set(weights.keys())
        if not required_names.issubset(provided_names):
            raise RuntimeError(
                f"Required but not provided tensors:{required_names.difference(provided_names)}"
            )
        if not provided_names.issubset(required_names):
            logger.warning(
                f"Provided but not required tensors: {provided_names.difference(required_names)}"
            )
        for name, param in self.named_parameters():
            if name in provided_names:
                if not from_pruned:
                    try:
                        param.value = weights[name]
                    except Exception as e:
                        raise RuntimeError(
                            f"Encountered error '{e}' while loading parameter '{name}'")
                else:
                    param.set_value_or_dummy(weights[name]) 
    def save_checkpoint(self, output_dir, save_config=True):
        # Multiple ranks may share the same config.json; the save_config parameter lets the caller avoid writing config.json from every rank.
        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        # If some tensors share memory, "save_file" will raise an error.
        # Clone such repeated tensors to prevent this issue.
        data_ptrs = set()
        for name, param in weights.items():
            if param.data_ptr() in data_ptrs:
                weights[name] = param.clone()
            data_ptrs.add(weights[name].data_ptr())
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        if save_config:
            self.config.to_json_file(os.path.join(output_dir, 'config.json')) 
    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'auto',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        device: str = 'cuda',
        calib_dataset: str = 'cnn_dailymail',
        calib_batches: int = 512,
        calib_batch_size: int = 1,
        calib_max_seq_length: int = 512,
        random_seed: int = 1234,
        tokenizer_max_seq_length: int = 2048,
        **kwargs,
    ):
        config_cls = getattr(cls, 'config_class', None)
        if config_cls is None:
            raise NotImplementedError(
                f"{cls.__name__} has not implemented a corresponding config class, which is needed for correct config parsing."
            )
        config: PretrainedConfig = config_cls.from_hugging_face(
            hf_model_dir,
            dtype=dtype,
            mapping=mapping,
            quant_config=quant_config,
            **kwargs)
        if config.mapping.moe_ep_size > 1:
            raise NotImplementedError(
                "Quantization for expert parallelism is not supported")
        if not config.quantization._requires_modelopt_quantization:
            raise ValueError(
                f"The quant_config ({quant_config}) does not require Modelopt quantization, so this method should not be called"
            )
        from ..quantization import quantize_and_export
        quantize_and_export(
            model_dir=str(hf_model_dir),
            device=device,
            calib_dataset=calib_dataset,
            dtype=config.dtype,
            qformat=config.quantization._get_modelopt_qformat(),
            kv_cache_dtype=config.quantization._get_modelopt_kv_cache_dtype(),
            calib_size=calib_batches,
            batch_size=calib_batch_size,
            calib_max_seq_length=calib_max_seq_length,
            awq_block_size=config.quantization.group_size,
            output_dir=output_dir,
            tp_size=config.mapping.tp_size,
            pp_size=config.mapping.pp_size,
            cp_size=config.mapping.cp_size,
            seed=random_seed,
            tokenizer_max_seq_length=tokenizer_max_seq_length,
        ) 
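# Example (illustrative sketch): a typical offline workflow built on PretrainedModel. The
# paths are hypothetical and LLaMAForCausalLM stands in for any PretrainedModel subclass.
#
#     # 1. (Optional) calibrate and quantize an HF model into a TRT-LLM checkpoint dir.
#     LLaMAForCausalLM.quantize('/path/to/hf_model', '/path/to/trtllm_ckpt',
#                               dtype='auto',
#                               quant_config=QuantConfig(quant_algo=QuantAlgo.FP8))
#     # 2. Load one rank of the checkpoint, then save it again (e.g. after edits).
#     model = LLaMAForCausalLM.from_checkpoint('/path/to/trtllm_ckpt', rank=0)
#     model.save_checkpoint('/path/to/output_ckpt', save_config=True)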
 
class DecoderModelForCausalLM(PretrainedModel):
    def __init__(self, config: PretrainedConfig, transformer, lm_head):
        super().__init__(config)
        self.transformer = transformer
        self.lm_head = lm_head
        self.mup_width_multiplier = getattr(config, 'mup_width_multiplier',
                                            None)
        # Create constant attention parameters to be reused by all layers.
        Attention.create_attention_const_params(self, config)
        self.position_embedding_type = config.position_embedding_type
    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                mrope_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None,
                spec_decoding_params=None):
        # fill attention params.
        attention_params = Attention.fill_attention_params(
            self, attention_params)
        # split the sequence for context parallelism
        if self.config.mapping.cp_size > 1:
            if len(input_ids.shape) == 1:
                # input shape is [-1]
                input_ids, cp_join_index = cp_split_plugin(
                    input_ids,
                    attention_params.host_request_types,
                    attention_params.host_context_lengths,
                    self.config.mapping.cp_size,
                    self.config.mapping.cp_rank,
                )
            else:
                assert False, "Context parallelism with non-remove-padding is not supported yet."
        is_gemma_2_cg = self.config.has_config_group(Gemma2ConfigGroup)
        is_gemma_3_cg = self.config.has_config_group(Gemma3ConfigGroup)
        kwargs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'kv_cache_params': kv_cache_params,
            'attention_params': attention_params,
        }
        if lora_params is not None:
            kwargs['lora_params'] = lora_params
        if hidden_states is not None:
            kwargs['hidden_states'] = hidden_states
        if prompt_embedding_table is not None:
            kwargs['prompt_embedding_table'] = prompt_embedding_table
        if prompt_tasks is not None:
            kwargs['prompt_tasks'] = prompt_tasks
        if prompt_vocab_size is not None:
            kwargs['prompt_vocab_size'] = prompt_vocab_size
        if spec_decoding_params is not None:
            kwargs['spec_decoding_params'] = spec_decoding_params
        if mrope_params is not None:
            kwargs['mrope_params'] = mrope_params
        hidden_states = self.transformer.forward(**kwargs)
        if use_cache:
            hidden_states, presents = hidden_states
        # All gather and rebuild sequence after transformer layer for context parallelism
        if self.config.mapping.cp_size > 1:
            if len(hidden_states.shape) == 2:
                hidden_states = allgather(hidden_states,
                                          self.config.mapping.cp_group,
                                          gather_dim=0)
                hidden_states = view(hidden_states,
                                     [-1, hidden_states.shape[-1]])
                hidden_states = index_select(hidden_states, 0, cp_join_index)
            else:
                assert False, "Context parallelism with non-remove-padding is not supported yet."
        if self.config.mapping.is_last_pp_rank():
            all_hidden_states = hidden_states
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)
            # [batch_size, hidden_size] -> [batch_size, vocab_size]
            lm_logits = self.lm_head(hidden_states)
            if hasattr(self.config, 'output_multiplier_scale'):
                lm_logits *= getattr(self.config, 'output_multiplier_scale', 1)
            if self.mup_width_multiplier is not None:
                lm_logits = lm_logits / self.mup_width_multiplier
            if is_gemma_2_cg or is_gemma_3_cg:
                softcap = self.config.get_config_group(
                    Gemma2ConfigGroup if not is_gemma_3_cg else
                    Gemma3ConfigGroup).final_logit_softcapping
                if softcap:
                    lm_logits = lm_logits * float(1 / softcap)
                    lm_logits = tanh(lm_logits) * float(softcap)
            lm_logits.mark_output('logits', self.config.logits_dtype)
        else:
            hidden_states.mark_output('hidden_states_output', self.config.dtype)
        if use_cache and not default_net().plugin_config.paged_kv_cache:
            for i, present in zip(
                    self.config.mapping.pp_layers(
                        self.config.num_hidden_layers), presents):
                present.mark_output(f'present_key_value_{i}',
                                    self.config.kv_dtype)
            if self.config.mapping.is_last_pp_rank():
                return (lm_logits, presents, hidden_states)
            return (hidden_states, presents)
        else:
            if self.config.mapping.is_last_pp_rank():
                return lm_logits, hidden_states, all_hidden_states
            return hidden_states
def fuse_gate_mlp(
    model: PretrainedModel,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
) -> PretrainedModel:
    from ..quantization.quantize import fp8_quantize
    for name, mlp, layer in model.named_modules_with_parent():
        if isinstance(mlp, GatedMLP):
            init_params = get_init_params(mlp)
            hidden_act = init_params["hidden_act"]
            if hidden_act not in ["silu", "gelu"]:
                logger.warning(
                    f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping."
                )
                continue
            init_params["inner_layernorm"] = mlp.inner_layernorm is not None
            fused_layer = FusedGatedMLP(**init_params)
            fc_name = name + '.fc'
            layer_quant_cfg = model.config._get_quant_cfg(fc_name)
            layer_quant_algo = layer_quant_cfg.quant_algo
            if layer_quant_algo != QuantAlgo.FP8 and layer_quant_algo is not None:
                continue
            if isinstance(model.config.quantization.exclude_modules, list) \
                    and fc_name in model.config.quantization.exclude_modules:
                layer_quant_algo = None
            if layer_quant_algo == QuantAlgo.FP8:
                fused_layer = fp8_quantize(fused_layer, layer_quant_cfg)
                if isinstance(mlp.dtype, str):
                    dtype = str_dtype_to_torch(mlp.dtype)
                else:
                    dtype = trt_dtype_to_torch(mlp.dtype)
                gate_weight = numpy_to_torch(mlp.gate.weight.raw_value)
                fc_weight = numpy_to_torch(mlp.fc.weight.raw_value)
                assert gate_weight.dtype == fc_weight.dtype
                need_qdq = gate_weight.dtype == torch.float8_e4m3fn
                gate_weight = gate_weight.to(dtype)
                fc_weight = fc_weight.to(dtype)
                # dequantize if needed
                if need_qdq:
                    gate_weight = gate_weight.to(dtype) * numpy_to_torch(
                        mlp.gate.weights_scaling_factor.raw_value)
                    fc_weight = fc_weight.to(dtype) * numpy_to_torch(
                        mlp.fc.weights_scaling_factor.raw_value)
                # concat
                fused_weight = torch.cat([gate_weight, fc_weight], dim=0)
                fused_weight_scaling_factor = numpy_to_torch(
                    max(
                        mlp.gate.weights_scaling_factor.raw_value,
                        mlp.fc.weights_scaling_factor.raw_value,
                    ))
                # quantize if needed
                if need_qdq:
                    fused_weight = (fused_weight /
                                    fused_weight_scaling_factor).to(
                                        torch.float8_e4m3fn)
                if gemm_swiglu_plugin_dtype == 'fp8' or low_latency_gemm_swiglu_plugin_dtype == 'fp8':
                    # gemm_swiglu_plugin needs (k, n) weights
                    # but weights should still be k-major for fp8
                    fused_layer.fused_fc.weight = Parameter(
                        shape=(fused_layer.fused_fc.in_features,
                               fused_layer.fused_fc.out_features),
                        dtype='fp8')
                    fused_layer.fused_fc.weight.value = fused_weight.view(
                        fused_layer.fused_fc.in_features,
                        fused_layer.fused_fc.out_features)
                else:
                    fused_layer.fused_fc.weight.value = fused_weight
                fused_layer.fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor
                fused_layer.fused_fc.activation_scaling_factor.value = max(
                    mlp.gate.activation_scaling_factor.raw_value,
                    mlp.fc.activation_scaling_factor.raw_value,
                )
            elif layer_quant_algo is None:
                fused_layer.fused_fc.weight.value = np.concatenate(
                    [
                        mlp.gate.weight.raw_value,
                        mlp.fc.weight.raw_value,
                    ],
                    axis=0,
                )
                if mlp.bias:
                    fused_layer.fused_fc.bias.value = np.concatenate(
                        [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value],
                        axis=0)
            else:
                raise ValueError(f'Unsupported quant algo: {layer_quant_algo}')
            fused_layer.proj = mlp.proj
            fused_layer.inner_layernorm = mlp.inner_layernorm
            _, mlp_name = name.rsplit('.', 1)
            setattr(layer, mlp_name, fused_layer)
        elif isinstance(mlp, Fp8RowwiseGatedMLP):
            init_params = get_init_params(mlp)
            hidden_act = init_params["hidden_act"]
            if hidden_act not in ["silu", "gelu"]:
                logger.warning(
                    f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping."
                )
                continue
            if mlp.clamp_val is not None:
                init_params["clamp_val"] = mlp.clamp_val.raw_value.tolist()
            fused_layer = Fp8RowwiseFusedGatedMLP(**init_params)
            fused_layer.fused_fc.weight.value = np.concatenate(
                [
                    mlp.gate.weight.raw_value,
                    mlp.fc.weight.raw_value,
                ],
                axis=0,
            )
            fused_layer.fused_fc.per_channel_scale.value = np.concatenate(
                [
                    mlp.gate.per_channel_scale.raw_value,
                    mlp.fc.per_channel_scale.raw_value,
                ],
                axis=0,
            )
            if mlp.bias:
                fused_layer.fused_fc.bias.value = np.concatenate(
                    [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value], axis=0)
            fused_layer.proj = mlp.proj
            _, mlp_name = name.rsplit('.', 1)
            setattr(layer, mlp_name, fused_layer)
    return model
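# Note on the fusion above: FusedGatedMLP expects the gate and fc (up) projections stacked
# along the output dimension, i.e. fused_fc.weight = concat([gate.weight, fc.weight], dim=0),
# giving roughly (2 * ffn_hidden_size // tp_size, hidden_size). For FP8, the two per-tensor
# weight scales are merged by taking their maximum and the weights are re-quantized to it.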
def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel:
    '''Split every Attention layer's fused QKV GEMM into three GEMMs (layer.q, layer.k, layer.v) and return the changed model.
    '''
    from ..quantization.quantize import quantize
    for name, layer in model.named_modules():
        if isinstance(layer, Attention) and not layer.cross_attention:
            assert layer.tp_size == 1, "please disable manual TP when auto parallel is enabled"
            if layer.qkv is None:
                continue
            qkv_params = get_init_params(layer.qkv, ColumnLinear)
            qkv_params["bias"] = qkv_params["bias"] is not None
            qkv_params["strict_dtype"] = qkv_params.get(
                "strict_dtype") is not None
            q = ColumnLinear(
                **{
                    **qkv_params,
                    "out_features":
                    layer.tp_size * layer.num_attention_heads *
                    layer.attention_head_size,
                })
            k = ColumnLinear(
                **{
                    **qkv_params,
                    "out_features":
                    layer.tp_size * layer.num_attention_kv_heads *
                    layer.attention_head_size,
                })
            v = ColumnLinear(
                **{
                    **qkv_params,
                    "out_features":
                    layer.tp_size * layer.num_attention_kv_heads *
                    layer.attention_head_size,
                })
            layer_quant_cfg = model.config._get_quant_cfg(name + '.qkv')
            q = quantize(q, layer_quant_cfg)
            k = quantize(k, layer_quant_cfg)
            v = quantize(v, layer_quant_cfg)
            out_features = q.out_features + k.out_features + v.out_features
            if isinstance(layer.qkv, (
                    WeightOnlyQuantLinear,
                    WeightOnlyQuantRowLinear,
                    WeightOnlyGroupwiseQuantLinear,
                    WeightOnlyGroupwiseQuantRowLinear,
            )):
                out_dim = 1
            else:
                out_dim = 0
            if layer.qkv.weight.is_inited():
                qkv_weight = layer.qkv.weight.raw_value
                weights = np.split(qkv_weight, [
                    qkv_weight.shape[out_dim] * q.out_features // out_features,
                    qkv_weight.shape[out_dim] *
                    (q.out_features + k.out_features) // out_features,
                ],
                                   axis=out_dim)
                for gemm, weight in zip([q, k, v], weights):
                    gemm.weight.value = weight
            if layer.qkv.bias is not None and layer.qkv.bias.is_inited():
                qkv_bias = layer.qkv.bias.raw_value
                biases = np.split(qkv_bias, [
                    qkv_bias.shape[out_dim] * q.out_features // out_features,
                    qkv_bias.shape[out_dim] *
                    (q.out_features + k.out_features) // out_features,
                ],
                                  axis=out_dim)
                for gemm, bias in zip([q, k, v], biases):
                    gemm.bias.value = bias
            for name, parameter in layer.qkv._parameters.items():
                if name not in ["weight", "bias"]:
                    for gemm in [q, k, v]:
                        setattr(gemm, name, parameter)
            layer.q = q
            layer.k = k
            layer.v = v
            layer.qkv = None
    return model
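# Worked example for the split above (hypothetical GQA sizes): with 32 query heads, 8 KV
# heads and head_size=128, out_features are q=4096, k=1024, v=1024 (total 6144), so
# np.split cuts the fused QKV weight at rows 4096 and 5120 along the output dimension.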
def fuse_rg_lru(model: PretrainedModel) -> PretrainedModel:
    for name, rg_lru, parent in model.named_modules_with_parent():
        if isinstance(rg_lru, RgLru):
            fused_layer = FusedRgLru(**get_init_params(rg_lru))
            fused_layer.gate.weight.value = np.concatenate(
                [
                    rg_lru.input_gate.weight.raw_value,
                    rg_lru.recurrent_gate.weight.raw_value,
                ],
                axis=-1,
            )
            fused_layer.gate.bias.value = np.concatenate(
                [
                    rg_lru.input_gate.bias.raw_value,
                    rg_lru.recurrent_gate.bias.raw_value,
                ],
                axis=-1,
            )
            fused_layer.recurrent_param.value = rg_lru.recurrent_param.raw_value
            rg_lru_name = name.rsplit('.', 1)[-1]
            setattr(parent, rg_lru_name, fused_layer)
    return model
def set_prompt_tuning(model: PretrainedModel) -> PretrainedModel:
    '''Replace the given model's embedding layer with a PromptTuningEmbedding layer in-place and return the changed model.
       Pre-conditions: vocab_embedding exists
       Post-conditions: isinstance(vocab_embedding, PromptTuningEmbedding)
    '''
    for name, embedding, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name == "vocab_embedding" and isinstance(embedding, Embedding):
            ptuning_embedding = PromptTuningEmbedding(
                **get_init_params(embedding))
            ptuning_embedding.weight.value = embedding.weight.raw_value
            parent.vocab_embedding = ptuning_embedding
    return model
def add_lora(model: PretrainedModel,
             max_lora_rank: Optional[int],
             with_dora: bool = False) -> PretrainedModel:
    ''' Add LoRA layers to the Attention/BertAttention/Linear/RowLinear/FusedGatedMLP layers of the given model and return the changed model.
    '''
    for name, layer in model.named_modules():
        max_rank = max_lora_rank
        if isinstance(layer, (Attention, BertAttention)):
            if max_rank is None:
                max_rank = min(
                    layer.hidden_size,
                    layer.num_attention_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size)
            layer.qkv_lora = Lora(
                in_hidden_size=layer.hidden_size,
                out_hidden_sizes=[
                    layer.num_attention_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size
                ],
                max_low_rank=max_rank,
            )
            if with_dora:
                layer.qkv_dora = Dora(out_hidden_sizes=[
                    layer.num_attention_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size
                ], )
        if isinstance(layer, (Linear, RowLinear)):
            if max_rank is None:
                max_rank = min(layer.in_features, layer.out_features)
            layer.lora = Lora(
                in_hidden_size=layer.in_features,
                out_hidden_sizes=[layer.out_features],
                max_low_rank=max_rank,
            )
            if with_dora:
                layer.dora = Dora(out_hidden_sizes=[layer.out_features])
        if isinstance(layer, (MLP, FusedGatedMLP)):
            if max_rank is None:
                max_rank = min(layer.hidden_size,
                               layer.ffn_hidden_size // layer.tp_size)
            layer.lora = Lora(
                in_hidden_size=layer.hidden_size,
                out_hidden_sizes=[
                    layer.ffn_hidden_size // layer.tp_size,
                    layer.ffn_hidden_size // layer.tp_size
                ],
                max_low_rank=max_rank,
            )
            if isinstance(layer, FusedGatedMLP):
                layer.fused_gate_up_lora = Lora(
                    in_hidden_size=layer.hidden_size,
                    out_hidden_sizes=[
                        layer.ffn_hidden_size * 2 // layer.tp_size
                    ],
                    max_low_rank=max_rank,
                )
            if with_dora:
                layer.dora = Dora(out_hidden_sizes=[
                    layer.ffn_hidden_size // layer.tp_size,
                    layer.ffn_hidden_size // layer.tp_size
                ], )
                if isinstance(layer, FusedGatedMLP):
                    layer.fused_gate_up_dora = Dora(out_hidden_sizes=[
                        layer.ffn_hidden_size * 2 // layer.tp_size
                    ], )
        if isinstance(layer, MOE):
            if max_rank is None:
                max_rank = min(layer.hidden_size,
                               layer.ffn_hidden_size // layer.tp_size)
            layer.max_low_rank = max_rank
    return model
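# Example (illustrative sketch): enabling LoRA (optionally DoRA) on a constructed model;
# max_lora_rank=64 is a hypothetical value. When max_lora_rank is None, each layer caps
# the rank by its own hidden sizes as computed above.
#
#     model = add_lora(model, max_lora_rank=64, with_dora=False)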
def to_ootb_moe(model: PretrainedModel) -> PretrainedModel:
    ''' Use OOTB MoE instead of the MoE plugin and return the changed model.
    '''
    for name, layer, parent in model.named_modules_with_parent():
        if isinstance(layer, MOE):
            layer_name = name.rsplit('.', 1)[-1]
            ootb_layer = layer.to(MoeOOTB, model.config.quantization)
            setattr(parent, layer_name, ootb_layer)
    return model
def parallelize_embedding(model: PretrainedModel) -> PretrainedModel:
    for name, embedding, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if isinstance(embedding, Embedding) and embedding.tp_group is None:
            init_params = get_init_params(embedding)
            init_params["tp_group"] = model.config.mapping.tp_group
            init_params["tp_size"] = model.config.mapping.tp_size
            init_params["tp_rank"] = model.config.mapping.tp_rank
            init_params["sharding_dim"] = model.config.embedding_sharding_dim
            new_embedding = embedding.__class__(**init_params)
            setattr(parent, layer_name, new_embedding)
    return model
def share_embedding(model: PretrainedModel) -> PretrainedModel:
    lm_head = None
    vocab_embedding = None
    for name, layer in model.named_modules():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name == "lm_head":
            lm_head = layer
        if layer_name == "vocab_embedding":
            vocab_embedding = layer
        if lm_head is not None and vocab_embedding is not None:
            break
    # Cannot find either lm_head or vocab_embedding, e.g., pipeline parallel
    if lm_head is None or vocab_embedding is None:
        return model
    # lm_head and vocab_embedding have different shapes, e.g., tensor parallel without embedding parallel
    if lm_head.weight.shape != vocab_embedding.weight.shape:
        return model
    # lm_head can have a different type if quantized
    if lm_head.weight.dtype != vocab_embedding.weight.dtype:
        return model
    # Don't assume weight can be shared if vocab_embedding is not initialized, e.g., dummy weights
    if not vocab_embedding.weight.is_inited():
        return model
    if lm_head.weight.is_inited():
        lm_head_weight = numpy_to_torch(lm_head.weight.raw_value)
        vocab_embed_weight = numpy_to_torch(vocab_embedding.weight.raw_value)
        # The lm_head and vocab_embedding have different weights
        if (lm_head_weight - vocab_embed_weight).abs().max().item() > 1e-6:
            return model
    lm_head.weight = vocab_embedding.weight
    if getattr(lm_head, 'per_channel_scale', None) and getattr(
            vocab_embedding, 'per_channel_scale', None):
        lm_head.per_channel_scale = vocab_embedding.per_channel_scale
    return model
def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel:
    for name, layer in model.named_modules():
        if isinstance(layer, Attention) and hasattr(
                layer.dense, 'activation_scaling_factor'):
            scale = [1.0] / layer.dense.activation_scaling_factor.raw_value
            layer.attention_output_orig_quant_scale = Parameter(
                value=scale.astype(np.float32), dtype='float32')
        elif isinstance(layer, Attention) and hasattr(
                layer.dense, 'activation_global_scaling_factor'):
            scale = [1.0
                     ] / layer.dense.activation_global_scaling_factor.raw_value
            layer.attention_output_orig_quant_scale = Parameter(
                value=scale.astype(np.float32), dtype='float32')
    return model
def set_fuse_fp4_quant(model: PretrainedModel) -> PretrainedModel:
    for name, layer in model.named_modules():
        if isinstance(layer, Attention) and hasattr(
                layer.dense, 'activation_global_scaling_factor'):
            scale = [1.0
                     ] / layer.dense.activation_global_scaling_factor.raw_value
            layer.attention_output_sf_scale = Parameter(value=scale.astype(
                np.float32),
                                                        dtype='float32')
    return model
def optimize_model(
    model: PretrainedModel,
    use_parallel_embedding: bool = False,
    share_embedding_table: bool = False,
    use_ootb_moe: bool = False,
    use_fused_mlp: bool = False,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    low_latency_gemm_swiglu_plugin_dtype: Optional[str] = None,
    use_fused_rg_lru: bool = False,
    use_unfused_qkv_gemm: bool = False,
    use_prompt_tuning: bool = False,
    use_lora: bool = False,
    max_lora_rank: Optional[int] = None,
    use_fp8_context_fmha: bool = False,
    fuse_fp4_quant: bool = False,
    use_optimize_cross_qkv: bool = False,
    use_dora: bool = False,
) -> PretrainedModel:
    """
    Run optimization passes on model.
    There are dependencies between some passes,
    so we always run passes in the order of arguments to guarantee the execution order.
    """
    # before weight loading
    if use_parallel_embedding:
        model = parallelize_embedding(model)
    if share_embedding_table:
        # If share_embedding_table is enabled, only one copy of the embedding table is stored in the converted checkpoint.
        # This pass is required to make lm_head.weight and vocab_embedding.weight point to the same tensor.
        # However, even if share_embedding_table is not enabled, TensorRT still keeps only one copy of the table when the weights are identical.
        model = share_embedding(model)
    # After weight loading
    if use_ootb_moe:
        model = to_ootb_moe(model)
    if use_fused_mlp:
        model = fuse_gate_mlp(model, gemm_swiglu_plugin_dtype,
                              low_latency_gemm_swiglu_plugin_dtype)
    if use_fused_rg_lru:
        model = fuse_rg_lru(model)
    if use_unfused_qkv_gemm:
        model = unfuse_qkv_gemm(model)
    if use_prompt_tuning:
        model = set_prompt_tuning(model)
    if use_lora:
        model = add_lora(model, max_lora_rank, with_dora=use_dora)
    if use_fp8_context_fmha:
        model = set_fp8_context_fhma(model)
    if fuse_fp4_quant:
        model = set_fuse_fp4_quant(model)
    if not use_lora and use_optimize_cross_qkv:
        # This optimization is not supported when LoRA is used.
        model = optimize_cross_qkv(model)
    return model
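# Illustrative sketch only (not called anywhere in this module): a typical
# invocation of the pass pipeline above. The flag values are placeholders; real
# callers derive them from build-time options rather than hard-coding them.
def _example_optimize_model_usage(model: PretrainedModel) -> PretrainedModel:
    return optimize_model(
        model,
        use_parallel_embedding=True,
        share_embedding_table=True,
        use_fused_mlp=True,
        use_lora=True,
        max_lora_rank=64,
    )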
def optimize_cross_qkv(model):
    """
    For cross attention layer, we can skip computing the query of encoder_output.
    So, add a new attribute 'kv' in the cross_attention layer. This might lead to
    additional memory cost on model size, but save the memory usage on runtime.
    Currently, this function only detect the ColumnLinear and FP8Linear. It does not supports
    other quantization now.
    """
    for name, attn, parent in model.named_modules_with_parent():
        if isinstance(attn, Attention) and attn.cross_attention and \
        (type(attn.qkv) == ColumnLinear or type(attn.qkv) == FP8Linear):
            old_qkv = attn.qkv
            linear_class = type(old_qkv)
            new_kv = linear_class(
                in_features=attn.hidden_size,
                out_features=2 * attn.tp_size * attn.num_attention_kv_heads *
                attn.attention_head_size,
                bias=old_qkv.bias,
                dtype=old_qkv.dtype,
                tp_group=old_qkv.tp_group,
                tp_size=old_qkv.tp_size,
                gather_output=old_qkv.gather_output,
                prefer_managed_weight=old_qkv.prefer_managed_weight,
                is_qkv=old_qkv.is_qkv,
            )
            old_qkv_weight_value = old_qkv.weight.raw_value
            if (old_qkv_weight_value.shape == np.asarray([
                (attn.num_attention_heads + 2 * attn.num_attention_kv_heads) *
                    attn.attention_head_size, attn.hidden_size
            ])).all():
                q_weight, kv_weight = np.array_split(
                    old_qkv_weight_value.reshape(
                        attn.num_attention_heads +
                        2 * attn.num_attention_kv_heads,
                        attn.attention_head_size, attn.hidden_size),
                    [attn.num_attention_heads],
                    axis=0)
                new_kv.weight.value = kv_weight.reshape([
                    2 * attn.num_attention_kv_heads * attn.attention_head_size,
                    attn.hidden_size
                ])
            elif (old_qkv_weight_value.shape == np.asarray([
                    attn.hidden_size,
                (attn.num_attention_heads + 2 * attn.num_attention_kv_heads) *
                    attn.attention_head_size
            ])).all():
                q_weight, kv_weight = np.array_split(
                    old_qkv_weight_value.reshape(
                        attn.hidden_size, attn.num_attention_heads +
                        2 * attn.num_attention_kv_heads,
                        attn.attention_head_size), [attn.num_attention_heads],
                    axis=1)
                new_kv.weight.value = kv_weight.reshape([
                    attn.hidden_size,
                    2 * attn.num_attention_kv_heads * attn.attention_head_size
                ])
            else:
                assert False, (
                    "Unexpected cross-attention qkv weight shape: "
                    f"{old_qkv_weight_value.shape}")
            if isinstance(attn.qkv, FP8Linear):
                new_kv.activation_scaling_factor.value = old_qkv.activation_scaling_factor.raw_value
                new_kv.weights_scaling_factor.value = old_qkv.weights_scaling_factor.raw_value
            if old_qkv.bias:
                q_bias, kv_bias = np.array_split(old_qkv.bias.raw_value.reshape(
                    attn.num_attention_heads + 2 * attn.num_attention_kv_heads,
                    attn.attention_head_size), [attn.num_attention_heads],
                                                 axis=0)
                new_kv.bias.value = kv_bias.reshape([
                    2 * attn.num_attention_kv_heads * attn.attention_head_size
                ])
            setattr(attn, "kv", new_kv)
    return model
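# Illustrative sketch only: the fused-QKV weight split performed by
# optimize_cross_qkv(), reproduced on a plain numpy array so the
# reshape/array_split bookkeeping is easier to follow. The head counts and
# sizes below are placeholders.
def _example_split_fused_qkv_weight() -> np.ndarray:
    num_q_heads, num_kv_heads, head_size, hidden_size = 8, 2, 64, 512
    fused = np.zeros(
        ((num_q_heads + 2 * num_kv_heads) * head_size, hidden_size),
        dtype=np.float32)
    # View the fused [q|k|v] output dimension as (heads, head_size) and split
    # off the first num_q_heads head rows; the remainder is the packed KV part.
    _q, kv = np.array_split(
        fused.reshape(num_q_heads + 2 * num_kv_heads, head_size, hidden_size),
        [num_q_heads],
        axis=0)
    # This shape matches the out_features of the new `kv` linear layer.
    return kv.reshape(2 * num_kv_heads * head_size, hidden_size)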
def preprocess_perlayer_weights(weights,
                                model_config,
                                quant_algo,
                                from_pruned=False):
    exclude_modules = model_config.quantization.exclude_modules
    # INT4_AWQ
    if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ:
        preprocessor = preprocess_weights_for_mixed_gemm
        if quant_algo == QuantAlgo.W4A8_AWQ:
            activation_type = torch.float8_e4m3fn
        elif quant_algo == QuantAlgo.W4A16_AWQ:
            activation_type = torch.float16
        for name, param in weights.items():
            if from_pruned and param.numel() == 0:
                continue
            if name.endswith('weight') and param.dtype == torch.int8:
                dtype = torch.float16
                if model_config.dtype == "bfloat16":
                    dtype = torch.bfloat16
                weights[name] = preprocessor(param.transpose(-1, -2),
                                             torch.quint4x2,
                                             activation_type).view(dtype)
            if name.endswith('weights_scaling_factor'):
                weights[name] = param.transpose(-1, -2).contiguous().to(
                    str_dtype_to_torch(model_config.dtype))
            if name.endswith('prequant_scaling_factor'):
                if len(weights[name].shape) == 2:
                    # MoE experts share the same scaling factor.
                    param = param[0, :]
                weights[name] = param.reshape(1, -1)
            if model_config.mapping.tp_rank > 0:
                if name.endswith('attention.dense.bias') or name.endswith(
                        'mlp.proj.bias'):
                    weights[name] = torch.zeros_like(param)
        if quant_algo == QuantAlgo.W4A8_AWQ:
            for name in list(weights):
                if name.endswith('weights_scaling_factor'):
                    activation_scaling_factor = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'activation_scaling_factor'))
                    weights_scaling_factor_2 = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'weights_scaling_factor_2'))
                    weights[name] /= weights_scaling_factor_2
                    weights[name] = weights[name].to(torch.float16).view(
                        str_dtype_to_torch(model_config.dtype))
                    weights[name.replace(
                        'weights_scaling_factor',
                        'prequant_scaling_factor')] /= activation_scaling_factor
                    weights[name.replace(
                        'weights_scaling_factor', 'alpha'
                    )] = activation_scaling_factor * weights_scaling_factor_2
                    weights[name.replace('weights_scaling_factor',
                                         'activation_scaling_factor'
                                         )] = activation_scaling_factor
    # FP8
    elif quant_algo == QuantAlgo.FP8:
        for name, param in weights.items():
            if name.endswith('weight') and param.dtype == torch.int8:
                weights[name] = param.view(torch.float8_e4m3fn)
        # lm_head is not always quantized to FP8
        if "lm_head.weight" in weights and weights[
                'lm_head.weight'].dtype is not torch.float8_e4m3fn:
            weights.pop('lm_head.weights_scaling_factor', None)
            weights.pop('lm_head.activation_scaling_factor', None)
    elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN:
        for name, param in weights.items():
            if name.endswith('weight') and param.dtype == torch.int8:
                weights[name] = param.view(torch.float8_e4m3fn)
        # lm_head is not quantized to FP8
        if "lm_head.weight" in weights:
            assert weights['lm_head.weight'].dtype == str_dtype_to_torch(
                model_config.dtype)
            weights.pop('lm_head.weights_scaling_factor', None)
            weights.pop('lm_head.activation_scaling_factor', None)
    # FP4
    elif quant_algo == QuantAlgo.NVFP4:
        # Interleave block scale for NVFP4 plugin.
        for name in list(weights):
            if name.endswith('weights_scaling_factor'):
                out_features, in_features = weights[name].shape
                nrows = fp4_utils.pad_up(out_features, 128)
                ncols = fp4_utils.pad_up(in_features, 4)
                new_name = name.replace('weights_scaling_factor',
                                        'weights_block_scaling_factor')
                weights[new_name] = weights[name]
                weights[
                    new_name +
                    "_interleaved"] = torch.ops.trtllm.nvfp4_block_scale_interleave(
                        weights[name].view(fp4_utils.float4_sf_dtype).cpu(
                        ).contiguous()).reshape(nrows, ncols).view(
                            fp4_utils.float4_sf_dtype)
                weights.pop(name)
            if name.endswith('weights_scaling_factor_2'):
                new_name = name.replace('weights_scaling_factor_2',
                                        'weights_global_scaling_factor')
                weights[new_name] = weights[name]
                weights.pop(name)
            if name.endswith('activation_scaling_factor'):
                new_name = name.replace('activation_scaling_factor',
                                        'activation_global_scaling_factor')
                weights[new_name] = weights[name]
                weights.pop(name)
        for name in list(weights):
            if name.endswith('weights_global_scaling_factor'):
                weight_global_sf = weights[name]
                act_global_sf = weights[name.replace(
                    'weights_global_scaling_factor',
                    'activation_global_scaling_factor')]
                weights[name.replace(
                    'weights_global_scaling_factor',
                    'alpha')] = act_global_sf * weight_global_sf
    elif quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
        weights = weight_only_quantize_dict(weights=weights,
                                            quant_algo=quant_algo,
                                            exclude_modules=exclude_modules,
                                            plugin=True)
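# Illustrative sketch only: the W4A8_AWQ scale folding performed above, shown
# on tiny placeholder tensors with the dtype-view details omitted, so the
# algebra between the different scaling factors is visible.
def _example_w4a8_awq_scale_folding() -> Dict[str, torch.Tensor]:
    layer = {
        'weights_scaling_factor': torch.tensor([0.02, 0.04]),
        'weights_scaling_factor_2': torch.tensor(0.5),
        'activation_scaling_factor': torch.tensor(0.25),
        'prequant_scaling_factor': torch.tensor([1.0, 2.0]),
    }
    asf = layer['activation_scaling_factor']
    wsf2 = layer.pop('weights_scaling_factor_2')
    # Fold the secondary weight scale into the per-group weight scale, rescale
    # the prequant scale by the activation scale, and store their product as
    # 'alpha'.
    layer['weights_scaling_factor'] = layer['weights_scaling_factor'] / wsf2
    layer['prequant_scaling_factor'] = layer['prequant_scaling_factor'] / asf
    layer['alpha'] = asf * wsf2
    return layer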
def preprocess_weights(weights: Dict[str, torch.Tensor],
                       model_config: PretrainedConfig,
                       from_pruned=False) -> Dict[str, torch.Tensor]:
    """Make weights and model_config compatible with each other and return the
    processed weights. Entries may be renamed or replaced, so callers should use
    the returned dict rather than the original one.
    Note: Typically, it should be called before model creation and weight loading. For example,
        weights = preprocess_weights(weights, model_config)
        model = XXXForCausalLM(model_config)
        model.load(weights)
    """
    quant_config = model_config.quantization
    quant_algo = quant_config.quant_algo
    pattern_info = ['fc', 'gate', 'proj', 'qkv', 'dense']
    def process_kv_scaling_factor(weights: Dict[str, torch.Tensor]):
        new_entries = {}
        names_to_delete = set()
        # If K and V cache scaling factors are stored separately, combine them into a single KV cache scaling factor.
        for name, param in weights.items():
            if name.endswith('.k_cache_scaling_factor'):
                v_name = name.replace('k_cache_scaling_factor',
                                      'v_cache_scaling_factor')
                assert v_name in weights, f"{v_name} not found"
                kv_name = name.replace('k_cache_scaling_factor',
                                       'kv_cache_scaling_factor')
                new_entries[kv_name] = torch.max(weights[name], weights[v_name])
                names_to_delete.update([name, v_name])
        weights.update(new_entries)
        for k in names_to_delete:
            del weights[k]
        new_entries = []
        # The unified converter generate_tllm_weights() already generates these rcp weights, but legacy
        # converters do not. Handle it here.
        for name, param in weights.items():
            if name.endswith('.kv_cache_scaling_factor'):
                rcp_name = name.replace('kv_cache_scaling_factor',
                                        'kv_cache_rcp_scaling_factor')
                if rcp_name not in weights:
                    new_entries.append((rcp_name, torch.reciprocal(param)))
        weights.update(new_entries)
    process_kv_scaling_factor(weights)
    per_layer_weights = {}
    for name, param in weights.items():
        in_mode = False
        for info in pattern_info:
            pattern = rf'(.*?{info}.*?)'
            pattern_match = re.match(pattern, name)
            if pattern_match:
                base_name = pattern_match.group(1)
                if base_name not in per_layer_weights.keys():
                    per_layer_weights[base_name] = {}
                per_layer_weights[base_name][name] = param
                in_mode = True
                break
        if not in_mode:
            # [lm_head.weight, ln_f.weight, vocab_embedding.weight]
            base_name = name.rsplit('.', 1)[0]
            if base_name not in per_layer_weights.keys():
                per_layer_weights[base_name] = {}
            per_layer_weights[base_name][name] = param
    new_weights = {}
    for base_name, layer_weights in per_layer_weights.items():
        if quant_algo != QuantAlgo.MIXED_PRECISION:
            layer_quant_algo = quant_algo
        else:
            quant_cfg = quant_config._get_quant_cfg(base_name)
            if not quant_cfg.quant_algo:
                new_weights.update(layer_weights)
                continue
            layer_quant_algo = quant_cfg.quant_algo
        preprocess_perlayer_weights(layer_weights, model_config,
                                    layer_quant_algo, from_pruned)
        new_weights.update(layer_weights)
    weights = new_weights
    for name, param in weights.items():
        if model_config.architecture == 'GPTJForCausalLM':
            if model_config.mapping.tp_rank > 0:
                if 'attention.dense.bias' in name or 'mlp.proj.bias' in name:
                    weights[name] = torch.zeros_like(param)
    return weights
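# Illustrative sketch only: the per-layer K/V cache scale merge that
# preprocess_weights() performs, shown on a toy dict. The layer name is a
# placeholder.
def _example_merge_kv_cache_scales() -> Dict[str, torch.Tensor]:
    prefix = 'transformer.layers.0.attention'
    toy = {
        f'{prefix}.k_cache_scaling_factor': torch.tensor(0.5),
        f'{prefix}.v_cache_scaling_factor': torch.tensor(0.8),
    }
    # The merged KV scale is the max of the separate K and V scales; its
    # reciprocal is also precomputed when it is missing.
    kv_scale = torch.max(toy.pop(f'{prefix}.k_cache_scaling_factor'),
                         toy.pop(f'{prefix}.v_cache_scaling_factor'))
    toy[f'{prefix}.kv_cache_scaling_factor'] = kv_scale
    toy[f'{prefix}.kv_cache_rcp_scaling_factor'] = torch.reciprocal(kv_scale)
    return toy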
def get_kv_cache_type_from_legacy(use_cache: bool,
                                  paged_kv_cache: bool) -> KVCacheType:
    if use_cache:
        if paged_kv_cache:
            return KVCacheType.PAGED
        else:
            return KVCacheType.CONTINUOUS
    else:
        return KVCacheType.DISABLED
def save_config(config: PretrainedConfig, *, output_dir: str,
                log: bool) -> None:
    config_path = Path(output_dir) / "config.json"
    if log:
        logger.debug(f"Saving TensorRT LLM configuration to {config_path}")
    config_path.parent.mkdir(exist_ok=True, parents=True)
    config_path.write_text(json.dumps(config.to_dict(), indent=4))
def save_checkpoint(*, output_dir: str, weights: dict, rank: int) -> None:
    """ Checkpoint saver for weight loader."""
    safetensors.torch.save_file(
        weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
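# Illustrative sketch only: how a converter might pair save_config() with
# save_checkpoint() to write a per-rank TensorRT LLM checkpoint. The `config`,
# `rank_weights`, and output-path values are placeholders.
def _example_save_rank_checkpoint(config: PretrainedConfig,
                                  rank_weights: Dict[str, torch.Tensor],
                                  rank: int,
                                  output_dir: str = './tllm_ckpt') -> None:
    if rank == 0:
        # Only one rank needs to write the shared config.json.
        save_config(config, output_dir=output_dir, log=True)
    os.makedirs(output_dir, exist_ok=True)
    save_checkpoint(output_dir=output_dir, weights=rank_weights, rank=rank)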