# Source code for tensorrt_llm.quantization.mode

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntFlag, auto
from typing import Optional

from strenum import StrEnum

from .._utils import BaseEnumMeta


class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
    """Names of the quantization algorithms known to this module.

    Every member value is produced by ``auto()``. How each algorithm maps
    onto ``QuantMode`` bit flags is defined in ``QuantMode.from_quant_algo``.
    """

    # Mapped to weight-only modes by QuantMode.from_quant_algo.
    W8A16 = auto()
    W4A16 = auto()
    W4A16_AWQ = auto()
    W4A8_AWQ = auto()
    W8A16_GPTQ = auto()
    W4A16_GPTQ = auto()

    # Mapped to smooth-quant modes (weights and activations quantized) by
    # QuantMode.from_quant_algo.
    W8A8_SQ_PER_CHANNEL = auto()
    W8A8_SQ_PER_TENSOR_PLUGIN = auto()
    W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN = auto()
    W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN = auto()
    W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN = auto()

    # FP8 maps to the FP8_QDQ flag, FP8_PER_CHANNEL_PER_TOKEN to FP8_ROWWISE.
    # FP8 is also accepted as a KV-cache algorithm.
    FP8 = auto()
    FP8_PER_CHANNEL_PER_TOKEN = auto()

    # Only accepted as a KV-cache algorithm (excluded from QUANT_ALGO_LIST).
    INT8 = auto()

    # Both fall through to QuantMode(0) in QuantMode.from_quant_algo.
    MIXED_PRECISION = auto()
    NO_QUANT = auto()
# Algorithms accepted as the `quant_algo` argument of
# QuantMode.from_quant_algo: every QuantAlgo member except INT8, which is only
# valid for the KV cache. Built in enum definition order so the list is
# deterministic across runs (a set-based construction would depend on string
# hash randomization).
QUANT_ALGO_LIST = [algo for algo in QuantAlgo if algo is not QuantAlgo.INT8]

# Algorithms accepted as the `kv_cache_quant_algo` argument of
# QuantMode.from_quant_algo.
KV_CACHE_QUANT_ALGO_LIST = [QuantAlgo.FP8, QuantAlgo.INT8]

# The smooth-quant algorithms whose names carry the `_PLUGIN` suffix.
W8A8_SQ_PLUGIN_LIST = [
    QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN,
]

# NOTE(review): presumably the algorithms handled by the ModelOpt quantization
# flow - confirm against the callers of this constant.
MODELOPT_FLOW_QUANTIZATIONS = {
    QuantAlgo.W4A16_AWQ,
    QuantAlgo.FP8,
    QuantAlgo.W8A8_SQ_PER_CHANNEL,
    QuantAlgo.W4A8_AWQ,
}
class QuantMode(IntFlag):
    """Bit flags describing which parts of a model are quantized and how.

    Instances are combined with ``|``. The ``is_*`` / ``has_*`` predicates
    test flags, the ``set_*`` helpers return a new mode with extra flags, and
    the static constructors build a mode from boolean options or from
    ``QuantAlgo`` names.
    """

    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/common/quantization.h

    # The weights are quantized to 4 bits.
    INT4_WEIGHTS = auto()
    # The weights are quantized to 8 bits.
    INT8_WEIGHTS = auto()
    # The activations are quantized.
    ACTIVATIONS = auto()
    # The method uses one scaling factor per channel. It's pre-computed (static) from the weights.
    PER_CHANNEL = auto()
    # The method uses one scaling factor per token. It's computed on-the-fly.
    PER_TOKEN = auto()
    # The method uses one scaling factor per group. It's pre-computed (static) from the weights.
    PER_GROUP = auto()
    # The KV cache is quantized in INT8.
    INT8_KV_CACHE = auto()
    # The KV cache is quantized in FP8.
    FP8_KV_CACHE = auto()
    # FP8 QDQ
    FP8_QDQ = auto()
    # FP8 rowwise
    FP8_ROWWISE = auto()

    # The smallest power-of-two that is not used by a flag. Do not call auto() after that line.
    COUNT = auto()

    # Bitmask to detect if weights, activations or both are quantized.
    WEIGHTS_AND_ACTIVATIONS = INT4_WEIGHTS | INT8_WEIGHTS | ACTIVATIONS

    # The mask of all valid flags.
    VALID_FLAGS = COUNT - 1

    def __deepcopy__(self, memo):
        """Return this same object instead of creating a copy."""
        return self

    def _all(self, bits, mask=VALID_FLAGS):
        """Return True if, restricted to ``mask``, exactly ``bits`` are set."""
        return (self & mask) == bits

    def _any(self, bits):
        """Return True if at least one of ``bits`` is set."""
        return (self & bits) != 0

    def is_int8_weight_only(self):
        """True if, among weight/activation flags, only INT8_WEIGHTS is set."""
        return self._all(self.INT8_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_int4_weight_only(self):
        """True if, among weight/activation flags, only INT4_WEIGHTS is set."""
        return self._all(self.INT4_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_weight_only(self):
        """True for INT4 or INT8 weight-only quantization."""
        return self.is_int4_weight_only() or self.is_int8_weight_only()

    def is_int8_weight_only_per_group(self):
        """True for INT8 weight-only quantization with PER_GROUP set."""
        return self.is_int8_weight_only() and self._any(self.PER_GROUP)

    def is_int4_weight_only_per_group(self):
        """True for INT4 weight-only quantization with PER_GROUP set."""
        return self.is_int4_weight_only() and self._any(self.PER_GROUP)

    def has_act_and_weight_quant(self):
        """True if, among weight/activation flags, exactly INT8_WEIGHTS and ACTIVATIONS are set."""
        return self._all(self.INT8_WEIGHTS | self.ACTIVATIONS,
                         self.WEIGHTS_AND_ACTIVATIONS)

    def has_act_or_weight_quant(self):
        """True if any weight or activation flag is set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS
                         | self.ACTIVATIONS)

    def has_per_token_dynamic_scaling(self):
        """True if PER_TOKEN is set."""
        return self._any(self.PER_TOKEN)

    def has_act_static_scaling(self):
        """True if neither PER_TOKEN nor FP8_ROWWISE is set."""
        return not self.has_per_token_dynamic_scaling(
        ) and not self.has_fp8_rowwise()

    def has_per_channel_scaling(self):
        """True if PER_CHANNEL is set."""
        return self._any(self.PER_CHANNEL)

    def has_per_group_scaling(self):
        """True if PER_GROUP is set."""
        return self._any(self.PER_GROUP)

    def has_int8_kv_cache(self):
        """True if INT8_KV_CACHE is set."""
        return self._any(self.INT8_KV_CACHE)

    def has_fp8_kv_cache(self):
        """True if FP8_KV_CACHE is set."""
        return self._any(self.FP8_KV_CACHE)

    def has_kv_cache_quant(self):
        """True if the KV cache is quantized (INT8 or FP8)."""
        return self.has_int8_kv_cache() or self.has_fp8_kv_cache()

    def has_fp8_qdq(self):
        """True if FP8_QDQ is set."""
        return self._any(self.FP8_QDQ)

    def has_fp8_rowwise(self):
        """True if FP8_ROWWISE is set."""
        return self._any(self.FP8_ROWWISE)

    def has_weight_quant(self):
        """True if INT4_WEIGHTS or INT8_WEIGHTS is set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS)

    def has_any_quant(self):
        """True if any weight, activation, KV-cache or FP8 flag is set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS
                         | self.ACTIVATIONS
                         | self.INT8_KV_CACHE | self.FP8_KV_CACHE
                         | self.FP8_QDQ | self.FP8_ROWWISE)

    def set_int8_kv_cache(self):
        """Return this mode with INT8_KV_CACHE added."""
        return self | self.INT8_KV_CACHE

    def set_fp8_kv_cache(self):
        """Return this mode with FP8_KV_CACHE added."""
        return self | self.FP8_KV_CACHE

    def set_fp8_qdq(self):
        """Return this mode with FP8_QDQ added."""
        return self | self.FP8_QDQ

    def set_fp8_rowwise(self):
        """Return this mode with FP8_ROWWISE, PER_TOKEN and PER_CHANNEL added."""
        return self | self.FP8_ROWWISE | self.PER_TOKEN | self.PER_CHANNEL

    @staticmethod
    def from_description(quantize_weights=False,
                         quantize_activations=False,
                         per_token=False,
                         per_channel=False,
                         per_group=False,
                         use_int4_weights=False,
                         use_int8_kv_cache=False,
                         use_fp8_kv_cache=False,
                         use_fp8_qdq=False,
                         use_fp8_rowwise=False):
        """Build a QuantMode from individual boolean options.

        Args:
            quantize_weights: Set INT4_WEIGHTS (if ``use_int4_weights``) or
                INT8_WEIGHTS.
            quantize_activations: Set ACTIVATIONS; requires
                ``quantize_weights``.
            per_token: Set PER_TOKEN; requires weights and activations to be
                quantized.
            per_channel: Set PER_CHANNEL; requires weights and activations to
                be quantized.
            per_group: Set PER_GROUP.
            use_int4_weights: Choose INT4 instead of INT8 weights; only has an
                effect when ``quantize_weights`` is true.
            use_int8_kv_cache: Set INT8_KV_CACHE.
            use_fp8_kv_cache: Set FP8_KV_CACHE.
            use_fp8_qdq: Set FP8_QDQ.
            use_fp8_rowwise: Set FP8_ROWWISE together with PER_TOKEN and
                PER_CHANNEL.

        Returns:
            The combined QuantMode.

        Raises:
            ValueError: If activations are quantized without weights, or if
                ``per_token``/``per_channel`` is requested without quantizing
                both weights and activations.
        """

        def raise_error():
            # Every field is followed by ", " so the message stays readable.
            raise ValueError(f"Unsupported combination of QuantMode args: "
                             f"{quantize_weights=}, "
                             f"{quantize_activations=}, "
                             f"{per_token=}, "
                             f"{per_channel=}, "
                             f"{per_group=}, "
                             f"{use_int4_weights=}, "
                             f"{use_int8_kv_cache=}, "
                             f"{use_fp8_kv_cache=}, "
                             f"{use_fp8_qdq=}, "
                             f"{use_fp8_rowwise=}")

        # We must quantize weights when we quantize activations.
        if quantize_activations and not quantize_weights:
            raise_error()

        # If we set per_token or per_channel, we must quantize both weights and activations.
        if (per_token or per_channel) and not (quantize_weights
                                               and quantize_activations):
            raise_error()

        mode = QuantMode(0)

        # Do we quantize the weights - if so, do we use INT4 or INT8?
        if quantize_weights and use_int4_weights:
            mode = mode | QuantMode.INT4_WEIGHTS
        elif quantize_weights:
            mode = mode | QuantMode.INT8_WEIGHTS

        # Do we quantize the activations?
        if quantize_activations:
            mode = mode | QuantMode.ACTIVATIONS

        # Per-channel/per-token/per-group additional flags.
        if per_channel:
            mode = mode | QuantMode.PER_CHANNEL
        if per_token:
            mode = mode | QuantMode.PER_TOKEN
        if per_group:
            mode = mode | QuantMode.PER_GROUP

        # Int8 KV cache
        if use_int8_kv_cache:
            mode = mode | QuantMode.INT8_KV_CACHE

        # FP8 KV cache
        if use_fp8_kv_cache:
            mode = mode | QuantMode.FP8_KV_CACHE

        if use_fp8_qdq:
            mode = mode | QuantMode.FP8_QDQ

        # FP8 rowwise always carries the per-token and per-channel flags,
        # matching set_fp8_rowwise().
        if use_fp8_rowwise:
            mode = mode | QuantMode.FP8_ROWWISE | QuantMode.PER_TOKEN | QuantMode.PER_CHANNEL

        return mode

    @staticmethod
    def use_smooth_quant(per_token=False, per_channel=False):
        """Return a mode with weights and activations quantized, plus the requested scaling flags."""
        return QuantMode.from_description(True, True, per_token, per_channel)

    @staticmethod
    def use_weight_only(use_int4_weights=False, per_group=False):
        """Return a weight-only mode (INT8 weights by default, INT4 if ``use_int4_weights``)."""
        return QuantMode.from_description(quantize_weights=True,
                                          quantize_activations=False,
                                          per_token=False,
                                          per_channel=False,
                                          per_group=per_group,
                                          use_int4_weights=use_int4_weights)

    @staticmethod
    def from_quant_algo(
        quant_algo: Optional["QuantAlgo"] = None,
        kv_cache_quant_algo: Optional["QuantAlgo"] = None,
    ) -> "QuantMode":
        """Translate algorithm names into a QuantMode.

        Args:
            quant_algo: None or a member of QUANT_ALGO_LIST. Algorithms
                without a dedicated branch below (and None) yield
                ``QuantMode(0)`` before the KV-cache flag is applied.
            kv_cache_quant_algo: None or a member of
                KV_CACHE_QUANT_ALGO_LIST.

        Returns:
            The QuantMode for ``quant_algo`` with the KV-cache flag added.

        Raises:
            AssertionError: If either argument is not None and not in its
                allowed list.
        """
        assert quant_algo is None or quant_algo in QUANT_ALGO_LIST, (
            f"Unsupported quant_algo: {quant_algo!r}")
        assert kv_cache_quant_algo is None or kv_cache_quant_algo in KV_CACHE_QUANT_ALGO_LIST, (
            f"Unsupported kv_cache_quant_algo: {kv_cache_quant_algo!r}")

        if quant_algo == QuantAlgo.W8A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=False)
        elif quant_algo == QuantAlgo.W4A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True)
        elif quant_algo == QuantAlgo.W4A16_AWQ:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W4A8_AWQ:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W4A16_GPTQ:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W8A16_GPTQ:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=False,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.FP8:
            quant_mode = QuantMode.from_description(use_fp8_qdq=True)
        elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN:
            quant_mode = QuantMode.from_description(use_fp8_rowwise=True)
        else:
            quant_mode = QuantMode(0)

        if kv_cache_quant_algo == QuantAlgo.INT8:
            quant_mode = quant_mode.set_int8_kv_cache()
        elif kv_cache_quant_algo == QuantAlgo.FP8:
            quant_mode = quant_mode.set_fp8_kv_cache()

        return quant_mode

    def to_dict(self):
        """Return the mode as a dict of named options.

        ``weight_only_precision`` is ``'int8'`` only when
        ``is_int8_weight_only()`` is true and ``'int4'`` in every other case,
        including modes that are not weight-only.
        """
        return {
            'use_smooth_quant':
            self.has_act_and_weight_quant(),
            'per_channel':
            self.has_per_channel_scaling(),
            'per_token':
            self.has_per_token_dynamic_scaling(),
            'per_group':
            self.has_per_group_scaling(),
            'int8_kv_cache':
            self.has_int8_kv_cache(),
            'enable_fp8':
            self.has_fp8_qdq(),
            'enable_fp8_rowwise':
            self.has_fp8_rowwise(),
            'fp8_kv_cache':
            self.has_fp8_kv_cache(),
            'use_weight_only':
            self.is_weight_only(),
            'weight_only_precision':
            'int8' if self.is_int8_weight_only() else 'int4',
        }