# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""
import copy
import json
import os
import random
import sys
import time
import numpy as np
import safetensors
import torch
from accelerate.hooks import remove_hook_from_module
from datasets import load_dataset
from safetensors.torch import load_file, save_file
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from .._utils import release_gc, str_dtype_to_torch
from ..logger import logger
from ..mapping import Mapping
from .mode import QuantAlgo
EMPTY_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"enable": False,
},
"*input_quantizer": {
"enable": False
},
"*lm_head*": {
"enable": False
},
"*output_layer*": {
"enable": False
},
"default": {
"enable": False
},
},
"algorithm": "max",
}
KV_CACHE_CFG = {
"*.query_key_value.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.Wqkv.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.W_pack.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.c_attn.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.k_proj.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
"*.v_proj.output_quantizer": {
"num_bits": 8,
"axis": None,
"enable": True
},
}
def quant_cfg_choices():
import modelopt.torch.quantization as mtq
QUANT_CFG_CHOICES = {
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"int8_wo": EMPTY_CFG,
"int4_wo": EMPTY_CFG,
"full_prec": EMPTY_CFG,
}
return QUANT_CFG_CHOICES
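# Illustrative lookup (comment only, not executed). Weight-only formats
# ("int8_wo", "int4_wo") and "full_prec" map to EMPTY_CFG, which disables every
# quantizer so ModelOpt leaves the weights untouched; the weight-only algorithm is
# recorded in the exported config instead (see the int8_wo/int4_wo handling in
# quantize_and_export below).
#
#     quant_cfg = quant_cfg_choices()["fp8"]      # mtq.FP8_DEFAULT_CFG
#     quant_cfg = quant_cfg_choices()["int4_wo"]  # EMPTY_CFG (no fake quantization)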
MODEL_NAME_PATTERN_MAP = {
"GPT2": "gpt2",
"Xverse": "llama",
"Llama": "llama",
"Mistral": "llama",
"GPTJ": "gptj",
"FalconForCausalLM": "falcon",
"RWForCausalLM": "falcon",
"baichuan": "baichuan",
"MPT": "mpt",
"Bloom": "bloom",
"ChatGLM": "chatglm",
"QWen": "qwen",
"Gemma2": "gemma2",
"Gemma": "gemma",
"MixtralForCausalLM": "llama",
"NemotronForCausalLM": "nemotron",
"GPTBigCodeForCausalLM": "gpt_bigcode",
"ArcticForCausalLM": "llama",
"Phi3SmallForCausalLM": "phi3small",
"Phi3ForCausalLM": "phi3",
"Starcoder2ForCausalLM": "gptnext",
"GPTBigCodeForCausalLM": "gptnext",
"GLM": "glm",
"Exaone": "exaone",
"DeciLMForCausalLM": "deci",
"DeepseekForCausalLM": "deepseek",
}
class _CustomDataset(torch.utils.data.Dataset):
def __init__(self, encodings):
self.encodings = encodings
def __getitem__(self, idx):
item = {
key: val[idx].clone().detach().requires_grad_(False)
for key, val in self.encodings.items()
}
return item
def __len__(self):
return len(self.encodings["input_ids"])
def get_tokenizer(ckpt_path, max_seq_length=2048, model_type=None):
logger.info(f"Initializing tokenizer from {ckpt_path}")
tokenizer = AutoTokenizer.from_pretrained(
ckpt_path,
model_max_length=max_seq_length,
padding_side="left",
trust_remote_code=True,
)
if tokenizer.pad_token is None:
if model_type and model_type == "qwen":
# Qwen uses token id 151643 as both the pad and eos token.
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
else:
tokenizer.pad_token = tokenizer.eos_token
assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!"
return tokenizer
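# Hypothetical usage sketch (comment only; the path is a placeholder):
#
#     tokenizer = get_tokenizer("/path/to/hf_model", max_seq_length=2048,
#                               model_type="qwen")
#
# If the tokenizer has no pad token, it falls back to the eos token
# (or to token id 151643 for Qwen).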
def _get_vila_model(model_dir):
sys.path.append(model_dir + "/../VILA")
from llava.model import LlavaLlamaConfig, LlavaLlamaModel # noqa
from transformers import AutoModel
model = AutoModel.from_pretrained(
model_dir,
device_map='auto',
trust_remote_code=True,
)
return model.llm
def get_hf_config(ckpt_path):
if "mpt" in ckpt_path:
# MPT-7B cannot be initialized from AutoConfig.
from transformers import MptConfig
return MptConfig.from_pretrained(ckpt_path)
else:
return AutoConfig.from_pretrained(ckpt_path, trust_remote_code=True)
def _get_llava_qwen_model(model_dir, dtype, device):
if "hf" in model_dir:
from transformers import LlavaOnevisionForConditionalGeneration
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
model_dir, torch_dtype=dtype, device_map=device)
model = model.language_model
else:
from llava.model.builder import load_pretrained_model
_, model, _, _ = load_pretrained_model(model_dir,
None,
'llava_qwen',
torch_dtype=dtype,
device_map=device)
return model
def get_model(ckpt_path: str, dtype: str = 'bfloat16', device: str = 'cuda'):
logger.info(f"Initializing model from {ckpt_path}")
# Note: the VILA model is not in the public HF model zoo yet; it must be imported explicitly from the VILA git repo.
hf_config = get_hf_config(ckpt_path)
torch_dtype = str_dtype_to_torch(dtype)
model_cls = AutoModelForCausalLM
if hf_config.model_type == "llava":
from transformers import LlavaForConditionalGeneration
model_cls = LlavaForConditionalGeneration
elif hf_config.model_type == "mpt":
from transformers import MptForCausalLM
model_cls = MptForCausalLM
if "vila" in ckpt_path:
model = _get_vila_model(ckpt_path)
elif "llava-onevision-qwen2" in ckpt_path:
model = _get_llava_qwen_model(ckpt_path, dtype, device)
elif hf_config.model_type == "glm":
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_path,
device_map="cuda",
torch_dtype=torch_dtype,
trust_remote_code=True)
else:
model = model_cls.from_pretrained(
ckpt_path,
device_map="auto" if device != "cpu" else "cpu",
torch_dtype="auto",
trust_remote_code=True)
if hf_config.model_type in ["llava", "internvl_chat"]:
model = model.language_model
model.eval()
model_dtype = next(model.parameters()).dtype
if torch_dtype != model_dtype:
logger.info(
f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, "
f"but the data type of the HuggingFace model is {model_dtype}.")
return model
def get_model_type(model):
for k, v in MODEL_NAME_PATTERN_MAP.items():
if k.lower() in type(model).__name__.lower():
return v
return None
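# Example of the substring matching above (comment only): a model class named
# "LlamaForCausalLM" matches the "Llama" pattern and returns "llama"; an
# unrecognized architecture returns None and is later exported as
# "unknown:<ClassName>" in quantize_and_export.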
def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
tokenizer=None,
batch_size=1,
calib_size=512,
block_size=512,
device=None,
include_labels=False):
logger.info("Loading calibration dataset")
if dataset_name_or_dir == "pileval":
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
dataset = dataset["text"][:calib_size]
elif "cnn_dailymail" in dataset_name_or_dir:
dataset = load_dataset(
dataset_name_or_dir,
name="3.0.0",
split="train",
)
dataset = dataset["article"][:calib_size]
elif os.path.isdir(dataset_name_or_dir):
logger.info(
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
"assuming the calibration data are in the train split and text column."
)
dataset = load_dataset(dataset_name_or_dir, split="train")
dataset = dataset["text"][:calib_size]
else:
raise NotImplementedError(
f"Unsupported dataset name or local repo directory: {dataset_name_or_dir}."
)
batch_encoded = tokenizer.batch_encode_plus(dataset,
return_tensors="pt",
padding=True,
truncation=True,
max_length=block_size)
if device:
batch_encoded = batch_encoded.to(device)
if include_labels:
# Labels are needed when backward is called in the model.
# The labels should be a shifted version of the input_ids.
# However, we should not shift the input_ids here since the labels are shifted by
# Huggingface models during loss calculation as shown here -
# https://github.com/huggingface/transformers/blob/7f79a97399bb52aad8460e1da2f36577d5dccfed/src/transformers/models/llama/modeling_llama.py#L1093-L1095
batch_encoded["labels"] = torch.where(
batch_encoded["attention_mask"] > 0.5, batch_encoded["input_ids"],
-100)
batch_encoded = _CustomDataset(batch_encoded)
else:
# For backward compatibility, if labels are not needed, we only return input_ids.
batch_encoded = _CustomDataset(
{"input_ids": batch_encoded["input_ids"]})
calib_dataloader = DataLoader(batch_encoded,
batch_size=batch_size,
shuffle=False)
return calib_dataloader
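# Hypothetical usage sketch (comment only; the tokenizer is a placeholder built
# by get_tokenizer above):
#
#     calib_loader = get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
#                                         tokenizer=tokenizer,
#                                         batch_size=1,
#                                         calib_size=512,
#                                         block_size=512,
#                                         device="cuda",
#                                         include_labels=True)
#
# include_labels is only needed for the auto-quantize path, where the loss is
# computed during scoring.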
def quantize_model(model, quant_cfg, calib_dataloader, batch_size, qformat,
auto_quantize_bits):
import modelopt.torch.quantization as mtq
# NOTE: for ModelOpt v0.19 release
# calibrate_loop = dataset_utils.create_forward_loop(
# calib_dataloader, dataloader=calib_dataloader)
def calibrate_loop():
if calib_dataloader is None:
return
with torch.no_grad():
low_mem_mode = False
for _, data in enumerate(calib_dataloader):
batch_size = data[list(data.keys())[0]].shape[0]
if batch_size == 1:
model(**data)
elif not low_mem_mode:
# Try running the forward pass once.
# If it runs out of memory, retry with the inputs split into two halves.
try:
model(**data)
except torch.OutOfMemoryError:
print(
"Warning: torch.OutOfMemoryError detected, try reducing the batch size..."
)
low_mem_mode = True
if low_mem_mode:
split_data_1 = {
key: data[key][:batch_size // 2, ...]
for key in data
}
model(**split_data_1)
split_data_2 = {
key: data[key][batch_size // 2:, ...]
for key in data
}
model(**split_data_2)
QUANT_CFG_CHOICES = {
"int8": "INT8_DEFAULT_CFG",
"int8_sq": "INT8_SMOOTHQUANT_CFG",
"fp8": "FP8_DEFAULT_CFG",
"int4_awq": "INT4_AWQ_CFG",
"w4a8_awq": "W4A8_AWQ_BETA_CFG",
}
logger.info("Starting quantization...")
start_time = time.time()
if auto_quantize_bits:
logger.info("Starting mixed precision quantization...")
model, search_history = mtq.auto_quantize(
model,
data_loader=calib_dataloader,
loss_func=lambda output, batch: output.loss,
constraints={"effective_bits": auto_quantize_bits},
forward_step=lambda model, batch: model(**batch),
quantization_formats=[
QUANT_CFG_CHOICES[item] for item in qformat.split(",")
] + [None],
num_calib_steps=len(calib_dataloader),
num_score_steps=min(
len(calib_dataloader), 128 // batch_size
), # Limit the number of score steps to avoid long calibration time
verbose=True,
)
mtq.print_quant_summary(model)
# We need to explicitly calibrate for kv cache quantization
enable_kv_cache_quantization = "int8" not in qformat
if enable_kv_cache_quantization:
mtq.set_quantizer_by_cfg(
model,
quant_cfg={
"*output_quantizer": {
"num_bits": (4, 3),
"axis": None,
"enable": True
}
},
)
# Calibrate only the output quantizers; disable all other quantizers during this pass.
with mtq.set_quantizer_by_cfg_context(model, {
"*": {
"enable": False
},
"*output_quantizer": {
"enable": True
}
}):
mtq.calibrate(model,
algorithm="max",
forward_loop=calibrate_loop)
else:
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
end_time = time.time()
logger.info(
"Quantization done. Total time used: {:.2f} s.".format(end_time -
start_time))
return model
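# Minimal sketch of the plain PTQ path above (comment only; `model` and
# `calibrate_loop` are placeholders for the objects built in quantize_model):
#
#     import modelopt.torch.quantization as mtq
#     model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)
#     mtq.print_quant_summary(model)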
def combine_medusa_weight(tp_size, pp_size, base_model_output_dir,
num_medusa_heads, num_medusa_layers, max_draft_len,
medusa_hidden_act, medusa_model_dir,
quant_medusa_head):
with open(f"{medusa_model_dir}/config.json", "r") as fp:
medusa_config = json.load(fp)
num_medusa_heads_from_config = medusa_config.get('medusa_num_heads',
num_medusa_heads)
num_medusa_layers = medusa_config.get('medusa_num_layers',
num_medusa_layers)
if num_medusa_heads is None:
num_medusa_heads = num_medusa_heads_from_config
assert max_draft_len > 0, "should have max_draft_len > 0"
world_size = tp_size * pp_size
# Process for each rank
for rank in range(world_size):
mapping = Mapping(world_size=world_size,
rank=rank,
tp_size=tp_size,
pp_size=pp_size)
# 1. Load medusa weight for each rank
from tensorrt_llm.models.medusa.weight import load_medusa_hf
medusa_weights = load_medusa_hf(medusa_path=medusa_model_dir,
num_medusa_heads=num_medusa_heads,
num_medusa_layers=num_medusa_layers,
mapping=mapping,
dtype="float16")
# 2. Load base model safetensors (after quant)
base_model_weights = load_file(
f"{base_model_output_dir}/rank{rank}.safetensors")
# 3. Combine and save weight
base_model_weights.update(medusa_weights)
save_file(base_model_weights,
f"{base_model_output_dir}/rank{rank}.safetensors")
# 4. Add medusa config into config.json
with open(f"{base_model_output_dir}/config.json", 'r') as f:
base_model_config = json.load(f)
with open(f"{base_model_output_dir}/config.json", 'w') as f:
base_model_config['architecture'] = "MedusaForCausalLM"
base_model_config['quantization']['exclude_modules'] = [
'lm_head',
'*router',
'*vocab_embedding',
'*position_embedding',
'*block_embedding',
]
if not quant_medusa_head:
base_model_config['quantization']['exclude_modules'].append(
'*medusa_heads*')
base_model_config['max_draft_len'] = max_draft_len
base_model_config['num_medusa_heads'] = num_medusa_heads
base_model_config['num_medusa_layers'] = num_medusa_layers
json.dump(base_model_config, f, indent=4)
torch.cuda.empty_cache()
logger.info("Combine medusa heads' weight, done.")
def quantize_and_export(*,
model_dir,
device,
calib_dataset,
dtype,
qformat,
kv_cache_dtype,
calib_size,
batch_size,
calib_max_seq_length,
awq_block_size,
output_dir,
tp_size,
pp_size,
seed,
tokenizer_max_seq_length,
num_medusa_heads=None,
num_medusa_layers=None,
max_draft_len=None,
medusa_hidden_act=None,
medusa_model_dir=None,
quant_medusa_head=None,
auto_quantize_bits=None):
'''
Load the model from model_dir, quantize it with ModelOpt, and export the
quantized model as a TensorRT-LLM checkpoint.
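
Example (illustrative only; the paths and sizes below are placeholders, not
recommendations):

    quantize_and_export(model_dir="/path/to/hf_model",
                        device="cuda",
                        calib_dataset="cnn_dailymail",
                        dtype="auto",
                        qformat="fp8",
                        kv_cache_dtype="fp8",
                        calib_size=512,
                        batch_size=1,
                        calib_max_seq_length=512,
                        awq_block_size=128,
                        output_dir="/path/to/trtllm_ckpt",
                        tp_size=1,
                        pp_size=1,
                        seed=42,
                        tokenizer_max_seq_length=2048)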
'''
try:
import modelopt # noqa
except ImportError as e:
logger.error(
"Failed to import modelopt, pls check the Modelopt installation. Currently it is known to be unsupported on Windows OS"
)
raise e
from modelopt.torch.export import export_tensorrt_llm_checkpoint
from tensorrt_llm.models.convert_utils import infer_dtype
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for inference.")
random.seed(seed)
np.random.seed(seed)
# Check that only one quantization format is provided for the non-auto-quantize case.
if not auto_quantize_bits:
assert (len(qformat.split(",")) == 1
), "Quantization supports only one quantization format."
hf_config = get_hf_config(model_dir)
dtype = infer_dtype(dtype, getattr(hf_config, 'torch_dtype', None))
model = get_model(model_dir, dtype, device=device)
model_type = get_model_type(model)
if "vila" in model_dir:
tokenizer = get_tokenizer(model_dir + "/llm",
max_seq_length=tokenizer_max_seq_length,
model_type=model_type)
else:
tokenizer = get_tokenizer(model_dir,
max_seq_length=tokenizer_max_seq_length,
model_type=model_type)
if qformat in ["full_prec", "int8_wo", "int4_wo"
] and kv_cache_dtype is None:
logger.info(f"No quantization applied, export {dtype} model")
else:
if "awq" in qformat:
if calib_size > 32:
logger.info(
f"AWQ calibration could take longer with calib_size = {calib_size}, Using"
" calib_size=32 instead")
calib_size = 32
logger.info(
"\nAWQ calibration could take longer than other calibration methods. Please"
" increase the batch size to speed up the calibration process. Batch size can be"
" set by adding the argument --batch_size <batch_size> to the command line.\n"
)
quant_cfg = None
if not auto_quantize_bits:
if qformat in quant_cfg_choices():
quant_cfg = quant_cfg_choices()[qformat]
else:
raise ValueError(f"Unsupported quantization format: {qformat}")
if "awq" in qformat:
quant_cfg = copy.deepcopy(quant_cfg_choices()[qformat])
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
if awq_block_size:
weight_quantizer["block_sizes"][-1] = awq_block_size
# Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
if "w4a8_awq" == qformat and model_type in ["gemma", "mpt"]:
quant_cfg["algorithm"] = {
"method": "awq_lite",
"alpha_step": 1
}
if kv_cache_dtype is not None:
if kv_cache_dtype == "fp8":
for value in KV_CACHE_CFG.values():
value.update({"num_bits": (4, 3)}) # type: ignore
quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore
# Gemma 7B has an accuracy regression with alpha 1, so we set 0.5 instead.
if model_type == "gemma" and "int8_sq" in qformat:
quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
calib_dataloader = get_calib_dataloader(
dataset_name_or_dir=calib_dataset,
tokenizer=tokenizer,
batch_size=batch_size,
calib_size=calib_size,
block_size=calib_max_seq_length,
device=device,
include_labels=auto_quantize_bits is not None,
)
model = quantize_model(model, quant_cfg, calib_dataloader, batch_size,
qformat, auto_quantize_bits)
with torch.inference_mode():
if model_type is None:
logger.info(
f"Unknown model type {type(model).__name__}. Continue exporting..."
)
model_type = f"unknown:{type(model).__name__}"
export_path = output_dir
start_time = time.time()
# Move meta tensor back to device before exporting.
remove_hook_from_module(model, recurse=True)
QUANT_ALGO = {
"int8": "INT8",
"int8_sq": "W8A8_SQ_PER_CHANNEL",
"fp8": "FP8",
"int4_awq": "W4A16_AWQ",
"w4a8_awq": "W4A8_AWQ",
}
export_tensorrt_llm_checkpoint(
model,
model_type,
getattr(torch, dtype),
export_dir=export_path,
inference_tensor_parallel=tp_size,
inference_pipeline_parallel=pp_size,
)
with open(f"{export_path}/config.json", "r") as f:
tensorrt_llm_config = json.load(f)
tensorrt_llm_config["model_type"] = model_type
# Workaround for weight-only (wo) quantization
if qformat in ["int8_wo", "int4_wo", "full_prec"]:
if qformat == "int8_wo":
tensorrt_llm_config["quantization"][
"quant_algo"] = QuantAlgo.W8A16
elif qformat == "int4_wo":
tensorrt_llm_config["quantization"][
"quant_algo"] = QuantAlgo.W4A16
else:
tensorrt_llm_config["quantization"]["quant_algo"] = None
# HF uses rope_scaling while tensorrt_llm uses rotary_scaling
if hasattr(
model.config,
"rope_scaling") and "rotary_scaling" not in tensorrt_llm_config:
tensorrt_llm_config["rotary_scaling"] = getattr(
model.config, "rope_scaling")
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
# Workaround for Modelopt 0.9.x fp8_kv_cache knob issue
if qformat == 'fp8' and kv_cache_dtype is None:
with open(f"{export_path}/config.json", "r") as f:
tensorrt_llm_config = json.load(f)
tensorrt_llm_config["quantization"]["kv_cache_quant_algo"] = None
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
# Workaround for share_embedding_table
if pp_size == 1:
with safetensors.safe_open(f"{export_path}/rank0.safetensors",
framework='pt',
device='cpu') as f:
share_embedding_table = 'lm_head.weight' not in f.keys()
if share_embedding_table:
with open(f"{export_path}/config.json", "r") as f:
tensorrt_llm_config = json.load(f)
tensorrt_llm_config["share_embedding_table"] = True
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
# Workaround for qwen version
if model_type == 'qwen':
with open(f"{export_path}/config.json", "r") as f:
tensorrt_llm_config = json.load(f)
qwen_config = AutoConfig.from_pretrained(model_dir,
trust_remote_code=True)
try:
from transformers import LlavaOnevisionConfig
if isinstance(qwen_config, LlavaOnevisionConfig):
qwen_config = qwen_config.text_config
except:
pass
tensorrt_llm_config["qwen_type"] = qwen_config.model_type
if qwen_config.model_type == "qwen2":
tensorrt_llm_config["norm_epsilon"] = qwen_config.rms_norm_eps
tensorrt_llm_config["rotary_base"] = qwen_config.rope_theta
tensorrt_llm_config[
"intermediate_size"] = qwen_config.intermediate_size
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
# Set rotary parameters correctly for chatglm.
if model_type == 'chatglm':
rotary_base = 10000.0
rotary_embedding_scaling = None
chatglm_config = AutoConfig.from_pretrained(model_dir,
trust_remote_code=True)
chatglm_version = tensorrt_llm_config['chatglm_version']
rope_ratio = tensorrt_llm_config.get('rope_ratio', 1.0)
if chatglm_version == 'chatglm2':
if rope_ratio > 1:
rotary_embedding_scaling = {
'type': 'linear',
'factor': rope_ratio
}
elif chatglm_version == 'chatglm3':
rotary_base *= rope_ratio
with open(f"{export_path}/config.json", "r") as f:
tensorrt_llm_config = json.load(f)
tensorrt_llm_config['rotary_base'] = rotary_base
tensorrt_llm_config['rotary_scaling'] = rotary_embedding_scaling
tensorrt_llm_config['rotary_pct'] = 0.5
with open(f"{export_path}/config.json", "w") as f:
json.dump(tensorrt_llm_config, f, indent=4)
# Workaround for combining the Medusa heads
# TODO: move this integration into ModelOpt to avoid redundant reading and writing
if medusa_model_dir is not None:
combine_medusa_weight(tp_size, pp_size, export_path,
num_medusa_heads, num_medusa_layers,
max_draft_len, medusa_hidden_act,
medusa_model_dir, quant_medusa_head)
end_time = time.time()
logger.info(
"Quantized model exported to {} \nTotal time used {:.2f} s.".format(
export_path, end_time - start_time))
# Need to delete the model and release memory explicitly;
# otherwise torch may retain its GPU memory until a delayed GC runs,
# which reduces the available GPU memory for subsequent stages.
del model
release_gc()
def unwrap_model(model, module_instances=None):
# Reference: https://github.com/NVIDIA/Megatron-LM/blob/core_r0.8.0/megatron/training/utils.py
from megatron.core import DistributedDataParallel as DDP
from megatron.core.transformer.module import Float16Module
if module_instances is None:
module_instances = (DDP, Float16Module)
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
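# Comment-only sketch: unwrap_model peels off DDP/Float16Module wrappers so the
# exporter sees the bare Megatron module, mirroring the call made below:
#
#     model.model = unwrap_model(model.model, Float16Module)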
def get_nemo_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
batch_size=64,
calib_size=512,
max_sequence_length=512):
if dataset_name_or_dir == "pileval":
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
text_column = "text"
elif "wikitext" in dataset_name_or_dir:
dataset = load_dataset(dataset_name_or_dir,
"wikitext-103-v1",
split="train")
text_column = "text"
elif "cnn_dailymail" in dataset_name_or_dir:
dataset = load_dataset(dataset_name_or_dir, name="3.0.0", split="train")
text_column = "article"
elif os.path.isdir(dataset_name_or_dir):
logger.info(
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
"assuming the calibration data are in the train split and text column."
)
dataset = load_dataset(dataset_name_or_dir, split="train")
text_column = "text"
else:
raise NotImplementedError(
f"Unsupported dataset name or local repo directory: {dataset_name_or_dir}."
)
calib_size = max(min(len(dataset), calib_size), batch_size)
for i in range(calib_size // batch_size):
batch = dataset[i * batch_size:(i + 1) * batch_size][text_column]
for j in range(len(batch)):
batch[j] = batch[j][:max_sequence_length]
yield batch
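# Hypothetical usage sketch (comment only): the function is a generator that yields
# lists of character-truncated text samples; quantize_nemo_and_export materializes
# it into a list before calibration.
#
#     for batch in get_nemo_calib_dataloader("cnn_dailymail", batch_size=64,
#                                            calib_size=512,
#                                            max_sequence_length=512):
#         ...  # each batch is a list of up to 64 strings, each at most 512 characters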
def quantize_nemo_and_export(*, nemo_ckpt_path, decoder_type, calib_dataset,
calib_tp_size, calib_pp_size, dtype, qformat,
kv_cache_dtype, calib_size, batch_size,
calib_max_seq_length, awq_block_size, output_dir,
tp_size, pp_size, seed):
try:
import modelopt # noqa
except ImportError as e:
logger.error(
"Failed to import modelopt, pls check the modelopt installation. Currently it is known to be unsupported on Windows OS"
)
raise e
import modelopt.torch.quantization as mtq
from megatron.core import parallel_state
from megatron.core.transformer.module import Float16Module
from modelopt.torch.export import export_tensorrt_llm_checkpoint
from modelopt.torch.utils import print_rank_0
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import \
MegatronGPTModel
from nemo.collections.nlp.modules.common.text_generation_strategy import \
GPTModelTextGenerationStrategy
from nemo.collections.nlp.parts.nlp_overrides import (
NLPDDPStrategy, NLPSaveRestoreConnector)
from nemo.utils.model_utils import load_config, save_artifacts
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for the inference.")
random.seed(seed)
np.random.seed(seed)
model_cfg = load_config(nemo_ckpt_path)
# dtype is used for non-quantized layers
supported_dtype = ["auto", "float16", "bfloat16"]
assert dtype in supported_dtype, f"{dtype} not supported. Supported dtypes are {supported_dtype}"
if dtype == 'auto':
dtype = model_cfg.get('precision', None)
if dtype is None:
dtype = 'float16'
elif 'bf16' in dtype or 'bfloat16' in dtype:
dtype = 'bfloat16'
else:
dtype = 'float16'
logger.info(f"Specified dtype 'auto'; inferred dtype {dtype!r}.")
torch_dtype = getattr(torch, dtype)
with open_dict(model_cfg):
model_cfg.activations_checkpoint_method = None
model_cfg.activations_checkpoint_granularity = None
model_cfg.tensor_model_parallel_size = calib_tp_size
model_cfg.pipeline_model_parallel_size = calib_pp_size
model_cfg.sequence_parallel = False
# Only custom modelopt spec is supported for PTQ: this custom spec is largely based on local Megatron-LM
# layer definitions to avoid Transformer Engine implementations that are currently not supported.
model_cfg.name = "modelopt"
# A trainer is required for restoring model-parallel models.
trainer_config = {
'devices': calib_tp_size * calib_pp_size,
'num_nodes': 1,
'accelerator': 'gpu',
'logger': False,
'precision': model_cfg.precision,
'enable_checkpointing': False,
}
trainer = Trainer(strategy=NLPDDPStrategy(), **trainer_config)
connector = NLPSaveRestoreConnector()
model = MegatronGPTModel.restore_from(
restore_path=nemo_ckpt_path,
trainer=trainer,
override_config_path=model_cfg,
save_restore_connector=connector,
)
model.freeze()
print_rank_0(model)
# Have to turn off activations_checkpoint_method for inference
try:
model.model.module.language_model.encoder.activations_checkpoint_method = None
except AttributeError:
pass
# Check whether the DDP is initialized
if parallel_state.is_unitialized():
def dummy():
return
if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
inference_config = {
'greedy': False,
'top_k': 0,
'top_p': 0.9,
'temperature': 1.0,
'add_BOS': True,
'tokens_to_generate': 30,
'all_probs': False,
'repetition_penalty': 1.2,
'min_tokens_to_generate': 0,
'compute_logprob': False,
'batch_size': batch_size,
'max_context_length': calib_max_seq_length,
'strategy': GPTModelTextGenerationStrategy(model),
}
model.set_inference_config(inference_config)
if qformat in ["full_prec", "int8_wo", "int4_wo"
] and kv_cache_dtype is None:
print_rank_0(f"No quantization applied, export {dtype} model")
else:
if "awq" in qformat:
if calib_size > 32:
print_rank_0(
"AWQ calibration could take longer with calib_size ="
f" {calib_size}, Using calib_size=32 instead")
calib_size = 32
print_rank_0(
"\nAWQ calibration could take longer than other calibration methods. Please"
" increase the batch size to speed up the calibration process. Batch size can be"
" set by adding the argument inference.batch_size=<batch_size> to the command"
" line.\n")
dataloader = get_nemo_calib_dataloader(
dataset_name_or_dir=calib_dataset,
batch_size=batch_size,
calib_size=calib_size,
max_sequence_length=calib_max_seq_length,
)
# =================== Start Quantization ====================
if qformat in quant_cfg_choices():
quant_cfg = quant_cfg_choices()[qformat]
else:
raise ValueError(f"Unsupported quantization format: {qformat}")
if "awq" in qformat:
quant_cfg = copy.deepcopy(quant_cfg_choices()[qformat])
weight_quantizer = quant_cfg["quant_cfg"][
"*weight_quantizer"] # type: ignore
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
weight_quantizer["block_sizes"][-1] = awq_block_size
if kv_cache_dtype is not None:
if kv_cache_dtype == "fp8":
for value in KV_CACHE_CFG.values():
value.update({"num_bits": (4, 3)}) # type: ignore
quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore
print_rank_0(quant_cfg)
# Always turn on FP8 kv cache to save memory footprint.
# For int8_sq, we use int8 kv cache.
# TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for nemotron.
# quant_cfg["quant_cfg"]["*output_quantizer"] = { # type: ignore[index]
# "num_bits": 8 if args.qformat == "int8_sq" else (4, 3),
# "axis": None,
# "enable": args.decoder_type != "gptnext",
# }
dataloader = list(dataloader)
def forward_loop(model):
for i, batch in enumerate(dataloader):
print_rank_0(f"Calibrating batch {i}")
model.predict_step(batch, i)
start_time = time.time()
model = mtq.quantize(model, quant_cfg,
forward_loop) # type: ignore[arg-type]
end_time = time.time()
tot_time = end_time - start_time
tput = calib_size / tot_time
print_rank_0(
f"Quantization done. Total time used {tot_time}s. Throughput {tput} samples/s"
)
# =================== End Quantization ======================
if decoder_type == "gptnext":
# We found squared_relu may have an under-calibration problem.
# Clamp the scaling_factor with a min threshold to avoid under-calibration.
maxbound = 0
if qformat == "fp8":
maxbound = 448
elif qformat == "int8_sq":
maxbound = 127
model = mtq.postprocess_amax(
model, "*input_quantizer",
lambda amax: torch.clamp(amax, min=0.01 * maxbound))
if torch.distributed.get_rank() == 0:
mtq.print_quant_summary(model)
if model_cfg.megatron_amp_O2:
model.model = unwrap_model(model.model, Float16Module)
start_time = time.time()
export_tensorrt_llm_checkpoint(
model,
decoder_type,
torch_dtype,
export_dir=output_dir,
inference_tensor_parallel=tp_size,
inference_pipeline_parallel=pp_size,
)
end_time = time.time()
print_rank_0(
f"Model config exported to: {output_dir}. Total time used {end_time - start_time}s"
)
if torch.distributed.get_rank() == 0:
save_artifacts(model, output_dir, use_abspath=True)
# Need to delete the model and release memory explicitly;
# otherwise torch may retain its GPU memory until a delayed GC runs,
# which reduces the available GPU memory for subsequent stages.
del model
release_gc()