# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ctypes
import platform
from collections import OrderedDict
from dataclasses import asdict, dataclass, field, fields
from enum import IntEnum
from pathlib import Path
from textwrap import dedent
from typing import List, Optional, Tuple
import tensorrt as trt
from .._ipc_utils import IpcMemory, can_access_peer
from ..logger import logger
from ..mapping import Mapping
# Plugin namespace string handed (UTF-8 encoded) to the plugin library's
# ``initTrtLlmPlugins`` entry point by ``_load_plugin_lib``.
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
def plugin_lib_path() -> str:
    """Return the absolute path of the TensorRT-LLM plugin shared library.

    The library is expected in a ``libs`` directory located in the parent of
    this file's directory. Its file name is platform dependent:
    ``nvinfer_plugin_tensorrt_llm.dll`` on Windows and
    ``libnvinfer_plugin_tensorrt_llm.so`` everywhere else.
    """
    if platform.system() == "Windows":
        lib_name = "nvinfer_plugin_tensorrt_llm.dll"
    else:
        lib_name = "libnvinfer_plugin_tensorrt_llm.so"
    package_root = Path(__file__).parent.parent.absolute()
    return str(package_root / "libs" / lib_name)
def _load_plugin_lib():
    """Load the TensorRT-LLM plugin library and register its plugins.

    Opens the library at ``plugin_lib_path()`` with ``RTLD_GLOBAL`` and calls
    its ``initTrtLlmPlugins`` entry point with a null ``void*`` and the UTF-8
    encoded ``TRT_LLM_PLUGIN_NAMESPACE``.

    Raises:
        OSError: propagated from ``ctypes.CDLL`` if the library cannot be
            loaded.
        ImportError: the library does not export ``initTrtLlmPlugins``.
        RuntimeError: calling ``initTrtLlmPlugins`` raised ``OSError``, or it
            returned false.
    """
    on_windows = platform.system() == "Windows"
    # ``mode`` is ignored on Windows and ``winmode`` only applies there; 0
    # passes no LoadLibraryEx flags (standard search order) instead of ctypes'
    # stricter default, presumably so dependent DLLs can be found -- TODO confirm.
    winmode = 0 if on_windows else None
    handle = ctypes.CDLL(plugin_lib_path(),
                         mode=ctypes.RTLD_GLOBAL,
                         winmode=winmode)
    try:
        init_plugins = handle.initTrtLlmPlugins
    except AttributeError as err:
        raise ImportError('TensorRT-LLM Plugin is unavailable') from err
    init_plugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    init_plugins.restype = ctypes.c_bool

    try:
        # Called outside of ``assert`` so plugin registration still happens
        # when Python runs with -O (which strips assert statements).
        initialized = init_plugins(None,
                                   TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8'))
    except OSError as e:
        windows_err = """
        The error above may be caused by an outdated Microsoft Visual C++ Redistributable Version.
        Please install the latest MSVC from the link below and re-launch.
        https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170#latest-microsoft-visual-c-redistributable-version
        """
        err_msg = dedent(windows_err if on_windows else "Unknown error")
        raise RuntimeError(err_msg) from e
    if not initialized:
        raise RuntimeError("Failed to initialize TensorRT-LLM plugins "
                           f"(namespace {TRT_LLM_PLUGIN_NAMESPACE!r})")
class ContextFMHAType(IntEnum):
    """Context FMHA mode, as exposed by ``PluginConfig.context_fmha_type``.

    The integer values are part of the enum's identity (``IntEnum``) and are
    compared directly by ``PluginConfig``.
    """
    # Context FMHA turned off.
    disabled = 0
    # FP16 I/O with FP16 accumulation.
    enabled = 1
    # FP16 I/O with FP32 accumulation.
    enabled_with_fp32_acc = 2
# Values a str-typed plugin option accepts unless PLUGIN_DTYPE_OPTIONS_MAP
# overrides them: "auto" follows PluginConfig.dtype, None disables the plugin.
DEFAULT_PLUGIN_DTYPE_OPTIONS = [
    "auto",
    "float16",
    "float32",
    "bfloat16",
    "int32",
    None,
]

# Per-option overrides of DEFAULT_PLUGIN_DTYPE_OPTIONS (also used to build the
# CLI ``choices`` in add_plugin_argument).
PLUGIN_DTYPE_OPTIONS_MAP = dict(
    gemm_swiglu_plugin=["fp8", None],
    gemm_plugin=["auto", "float16", "float32", "bfloat16", "int32", "fp8", None],
    low_latency_gemm_plugin=["fp8", None],
)
def _make_plugin_property(field_name: str, field_type: type):
    """Build the validating property that fronts the ``_<field_name>`` slot.

    Reading returns the stored value, except that for every field other than
    ``dtype`` the stored value ``'auto'`` resolves to the current ``dtype``.

    Writing checks the value before storing it:

    * ``bool`` fields accept only ``bool`` instances;
    * ``str`` / ``Optional[str]`` fields accept only the values listed for the
      field in ``PLUGIN_DTYPE_OPTIONS_MAP`` (falling back to
      ``DEFAULT_PLUGIN_DTYPE_OPTIONS``), and ``dtype`` additionally rejects
      ``'auto'`` and ``None``;
    * fields of any other type are stored unchecked.

    Checks are ``assert`` statements, so they raise ``AssertionError``. Each
    successful write is logged at info level.
    """
    storage_name = f'_{field_name}'

    def getter(self):
        stored = getattr(self, storage_name)
        if field_name == 'dtype' or stored != 'auto':
            return stored
        return self.dtype

    def setter(self, value):
        if field_type is bool:
            assert isinstance(value, bool), \
                f"Plugin {field_name} expects {field_type}, got {type(value)}"
        elif field_type in (str, Optional[str]):
            allowed = PLUGIN_DTYPE_OPTIONS_MAP.get(field_name,
                                                   DEFAULT_PLUGIN_DTYPE_OPTIONS)
            assert value in allowed, \
                f"Plugin {field_name} expects values in {allowed}, got {value}"
            if field_name == 'dtype':
                assert value not in ['auto', None], \
                    "Plugin dtype cannot be auto or None"
        setattr(self, storage_name, value)
        logger.info(f"Set {field_name} to {value}.")

    return property(getter, setter)
class PluginConfigMeta(type):
    """Metaclass that turns ``_<name>`` annotated fields into ``<name>`` properties.

    For every annotation in the class body, the storage name must start with
    ``_`` (asserted); a property named without the leading underscores is added
    to the class, built by ``_make_plugin_property`` from the annotated type.
    """

    def __new__(cls, name, bases, attrs):
        # ``__annotations__`` only exists in the class namespace when the body
        # declares at least one annotated name; a class (e.g. a subclass) that
        # only adds methods would otherwise fail with KeyError here.
        annotations = attrs.get('__annotations__', {})
        for storage_name, field_type in annotations.items():
            assert storage_name.startswith('_')
            field_name = storage_name.lstrip('_')
            attrs[field_name] = _make_plugin_property(field_name, field_type)
        return super().__new__(cls, name, bases, attrs)
@dataclass(slots=True)
class PluginConfig(metaclass=PluginConfigMeta):
    """The config that manages plugin-related options.

    There are two option categories:

    * Plugin options (typically with xxx_plugin naming). These options can be
      assigned with:

      * "float16"/"bfloat16"/"float32"/"int32", which means the plugin is
        enabled with the specified precision (some plugins only support a
        limited set of dtypes, e.g., gemm_swiglu_plugin currently only
        supports fp8);
      * "auto", which means the plugin is enabled with the precision of the
        `dtype` field (the `dtype` field must match the model dtype, i.e., the
        one in PretrainedConfig);
      * None, which means the plugin is disabled.

    * Other features. These options can be assigned with boolean:

      * True, which means the feature is enabled;
      * False, which means the feature is disabled.

    Note: All the fields should use a prefix "_"; PluginConfigMeta wraps each
    field as a property named without the prefix. For bool- and str-typed
    fields the property setter rejects values outside the allowed set.
    """
    _dtype: str = field(default="float16", init=False)

    # Plugins
    _bert_attention_plugin: Optional[str] = field(default="auto", init=False)
    _gpt_attention_plugin: Optional[str] = field(default="auto", init=False)
    _gemm_plugin: Optional[str] = field(default=None, init=False)
    _gemm_swiglu_plugin: Optional[str] = field(default=None, init=False)
    _fp8_rowwise_gemm_plugin: Optional[str] = field(default=None, init=False)
    _smooth_quant_gemm_plugin: Optional[str] = field(default=None, init=False)
    _identity_plugin: Optional[str] = field(default=None, init=False)
    _layernorm_quantization_plugin: Optional[str] = field(default=None,
                                                          init=False)
    _rmsnorm_quantization_plugin: Optional[str] = field(default=None,
                                                        init=False)
    _nccl_plugin: Optional[str] = field(default="auto", init=False)
    _lora_plugin: Optional[str] = field(default=None, init=False)
    _weight_only_groupwise_quant_matmul_plugin: Optional[str] = field(
        default=None, init=False)
    _weight_only_quant_matmul_plugin: Optional[str] = field(default=None,
                                                            init=False)
    _smooth_quant_plugins: bool = field(
        default=True,
        init=False)  # Always enable smooth quant plugins for external users
    _quantize_per_token_plugin: bool = field(default=False, init=False)
    _quantize_tensor_plugin: bool = field(default=False, init=False)
    _moe_plugin: Optional[str] = field(default="auto", init=False)
    _mamba_conv1d_plugin: Optional[str] = field(default="auto", init=False)
    _low_latency_gemm_plugin: Optional[str] = field(default=None, init=False)

    # Features
    _context_fmha: bool = field(default=True, init=False)
    _bert_context_fmha_fp32_acc: bool = field(
        default=False, init=False)  # will use fp16 if disabled
    _paged_kv_cache: Optional[bool] = field(default=None, init=False)
    _remove_input_padding: bool = field(default=True, init=False)
    _reduce_fusion: bool = field(default=False, init=False)
    _enable_xqa: bool = field(default=True, init=False)
    _tokens_per_block: int = field(default=64, init=False)
    _use_paged_context_fmha: bool = field(default=False, init=False)
    _use_fp8_context_fmha: bool = field(default=False, init=False)
    _multiple_profiles: bool = field(default=False, init=False)
    _paged_state: bool = field(default=True, init=False)
    _streamingllm: bool = field(default=False, init=False)
    _manage_weights: bool = field(default=False, init=False)
    _use_fused_mlp: bool = field(default=True, init=False)
    _pp_reduce_scatter: bool = field(default=False, init=False)

    def update_from_dict(self, config: dict):
        """Assign every entry of ``config`` whose key is an attribute of self.

        CLI spellings are translated before assignment: for options whose
        current value is a ``bool`` (and for ``paged_kv_cache``),
        ``"enable"``/``"disable"`` become ``True``/``False``; for any other
        option ``"disable"`` becomes ``None``. Keys for which ``hasattr`` is
        false are ignored. NOTE(review): ``hasattr`` also matches methods and
        non-field properties such as ``context_fmha_type``.
        """
        for name, value in config.items():
            if not hasattr(self, name):
                continue
            if isinstance(getattr(self, name),
                          bool) or name == 'paged_kv_cache':
                if value == "enable":
                    value = True
                elif value == "disable":
                    value = False
            elif value == "disable":
                value = None
            setattr(self, name, value)

    @classmethod
    def from_dict(cls, config: dict):
        """Create a default config and apply ``update_from_dict(config)``."""
        plugin_config = cls()
        plugin_config.update_from_dict(config)
        return plugin_config

    @classmethod
    def from_arguments(cls, args: argparse.Namespace):
        """Create a config from parsed CLI arguments (``vars(args)``)."""
        return cls.from_dict(vars(args))

    def to_dict(self):
        """Return the fields as a dict keyed by option name (no "_" prefix).

        Values are the raw stored values: options stored as "auto" are
        reported as "auto", not resolved to ``dtype``.
        """
        return {key.lstrip('_'): value for key, value in asdict(self).items()}

    def to_legacy_setting(self):
        '''Disable all plugins and features, leaving ``dtype`` unchanged.

        Legacy setting means that all of the plugins and features are
        disabled; this is needed for the legacy `build.py` script, which will
        be migrated to the centralized building script
        `tensorrt_llm/commands/build.py`. After the migration is done, this
        function may or may not be deleted.

        str-typed options are set to None; bool options and
        ``paged_kv_cache`` are set to False; other fields (e.g.
        ``tokens_per_block``) are not touched.
        '''
        # ``config_field`` (not ``field``) avoids shadowing dataclasses.field.
        for config_field in fields(self):
            # Remove prefix "_" of the storage name
            field_name = config_field.name.lstrip('_')
            if field_name == 'dtype':
                continue
            if config_field.type in (str, Optional[str]):
                setattr(self, field_name, None)
            elif config_field.type == bool or field_name == 'paged_kv_cache':
                setattr(self, field_name, False)

    @property
    def context_fmha_type(self):
        """ContextFMHAType derived from the two context-FMHA flags.

        ``bert_context_fmha_fp32_acc`` takes precedence: when it is True the
        result is ``enabled_with_fp32_acc`` regardless of ``context_fmha``.
        """
        if self.bert_context_fmha_fp32_acc:
            return ContextFMHAType.enabled_with_fp32_acc
        elif self.context_fmha:
            return ContextFMHAType.enabled
        else:
            return ContextFMHAType.disabled

    @context_fmha_type.setter
    def context_fmha_type(self, value):
        if value == ContextFMHAType.disabled:
            self.context_fmha = False
            self.bert_context_fmha_fp32_acc = False
        else:
            self.context_fmha = True
            if value == ContextFMHAType.enabled:
                self.bert_context_fmha_fp32_acc = False
            elif value == ContextFMHAType.enabled_with_fp32_acc:
                self.bert_context_fmha_fp32_acc = True

    def is_context_fmha_enabled(self):
        """Return True unless ``context_fmha_type`` is ``disabled``."""
        return self.context_fmha_type != ContextFMHAType.disabled

    def set_smooth_quant_plugins(self, dtype: str = "auto"):
        """Enable the SmoothQuant GEMM/norm/quantize plugins; returns self."""
        self.smooth_quant_gemm_plugin = dtype
        self.rmsnorm_quantization_plugin = dtype
        self.layernorm_quantization_plugin = dtype
        self.quantize_per_token_plugin = True
        self.quantize_tensor_plugin = True
        return self

    def set_fp8_rowwise_quant_plugins(self, dtype: str = "auto"):
        """Enable the FP8 row-wise GEMM/norm/quantize plugins; returns self."""
        self.fp8_rowwise_gemm_plugin = dtype
        self.rmsnorm_quantization_plugin = dtype
        # NOTE(review): layernorm_quantization_plugin is deliberately not set
        # for FP8 row-wise quantization -- confirm this is still intended.
        self.quantize_per_token_plugin = True
        self.quantize_tensor_plugin = True
        return self

    def set_context_fmha(self, context_fmha_type=ContextFMHAType.enabled):
        """Set ``context_fmha_type`` (must be a ContextFMHAType); returns self."""
        assert isinstance(context_fmha_type, ContextFMHAType)
        self.context_fmha_type = context_fmha_type
        return self

    def enable_paged_kv_cache(self, tokens_per_block: int = 64):
        """Turn on the paged KV cache with the given block size; returns self."""
        self.paged_kv_cache = True
        self.tokens_per_block = tokens_per_block
        return self

    def set_nccl_plugin(self, dtype: str = "auto"):
        """Set ``nccl_plugin`` and (re)create the all-reduce helper; returns self."""
        self.nccl_plugin = dtype
        # Replaces the module-level custom_all_reduce_helper singleton.
        init_all_reduce_helper()
        return self
# PluginConfig options exposed as ``--<name>`` command-line flags by
# add_plugin_argument; any other PluginConfig field gets no CLI flag.
cli_plugin_args = [
    # Plugins
    "bert_attention_plugin", "gpt_attention_plugin",
    "gemm_plugin", "gemm_swiglu_plugin", "fp8_rowwise_gemm_plugin",
    "lora_plugin", "moe_plugin", "mamba_conv1d_plugin",
    "nccl_plugin", "low_latency_gemm_plugin",
    # Features
    "context_fmha", "bert_context_fmha_fp32_acc", "remove_input_padding",
    "enable_xqa", "tokens_per_block",
    "use_paged_context_fmha", "use_fp8_context_fmha",
    "multiple_profiles", "paged_state", "streamingllm",
    "reduce_fusion", "use_fused_mlp", "pp_reduce_scatter",
]
def add_plugin_argument(parser: argparse.ArgumentParser):
    """Register a ``--<name>`` option for every PluginConfig field in cli_plugin_args.

    The option shape follows the field's annotated type:

    * ``str`` / ``Optional[str]``: a string restricted to the field's dtype
      options, with ``None`` spelled ``"disable"`` in both the choices and the
      default;
    * ``bool``: ``"enable"`` or ``"disable"``, defaulting from the field;
    * anything else: the annotated type and the field default.

    Returns the same ``parser`` so calls can be chained.
    """
    for config_field in fields(PluginConfig):
        # Strip the "_" storage prefix to get the public option name.
        option = config_field.name.lstrip('_')
        if option not in cli_plugin_args:
            continue
        flag = "--" + option
        default = config_field.default

        if config_field.type in (str, Optional[str]):
            dtype_choices = PLUGIN_DTYPE_OPTIONS_MAP.get(
                option, DEFAULT_PLUGIN_DTYPE_OPTIONS)
            parser.add_argument(
                flag,
                type=str,
                default=default if default else "disable",
                choices=[choice if choice else "disable"
                         for choice in dtype_choices],
                help=f"Whether to enable/disable ``{option}`` and the dtype.")
        elif config_field.type == bool:
            parser.add_argument(
                flag,
                type=str,
                default="enable" if default else "disable",
                choices=["enable", "disable"],
                help=f"Whether to enable/disable ``{option}``.")
        else:
            parser.add_argument(flag,
                                type=config_field.type,
                                default=default,
                                help=f"``{option}``.")
    return parser
class CustomAllReduceHelper:
    """
    Globally visible class to help usage of custom_all_reduce plugin.

    Provides the following utilities:

    workspace: Tensor
        When using CUSTOM or AUTO mode, a tensor containing pointers to memory
        visible to all GPUs. Its length is ``POINTERS_PER_RANK * tp_size + 2``:
        ``allocate_workspace`` fills it with the serialized pointers of seven
        IPC allocations (ping and pong data buffers, barriers in, barriers out,
        and three Lamport buffers), in that order, followed by two zeros.
        It must be initialized using IpcMemory class.

    Usage:
        - Set custom_all_reduce_helper.workspace with the required tensor
          (see ``set_workspace_tensor``).
        Then, each instance of allreduce will reference that tensor
        automatically.
    """

    # Number of IPC allocations serialized per TP rank into the workspace; must
    # match the buffer list built in ``allocate_workspace``.
    POINTERS_PER_RANK = 7

    def __init__(self) -> None:
        # Quoted: ``Tensor`` is only imported inside ``set_workspace_tensor``.
        self.workspace: Optional["Tensor"] = None

    def set_workspace_tensor(self,
                             mapping: Mapping,
                             num_profiles: Optional[int] = None):
        """Create ``self.workspace`` as an int64 ``functional.Tensor``.

        Args:
            mapping: provides ``tp_size``, which sets the tensor length
                ``POINTERS_PER_RANK * tp_size + 2``.
            num_profiles: when given, ``dim_range`` maps ``all_reduce_size``
                to that same fixed length for each of the profiles.
        """
        from ..functional import Tensor

        workspace_size = self.POINTERS_PER_RANK * mapping.tp_size + 2

        dim_range = None
        if num_profiles is not None:
            dim_range = OrderedDict([('all_reduce_size',
                                      [workspace_size] * num_profiles)])

        self.workspace = Tensor(
            name='all_reduce_workspace',
            dtype=trt.int64,
            shape=[workspace_size],
            dim_range=dim_range,
        )

    @staticmethod
    def max_workspace_size_auto(tp_size: int) -> int:
        """Return 16_000_000 for ``tp_size <= 2``, otherwise 8_000_000.

        Presumably a size in bytes for AUTO mode -- TODO confirm units.
        """
        if tp_size <= 2:
            return 16_000_000
        return 8_000_000

    @staticmethod
    def allocate_workspace(mapping: Mapping,
                           size: int) -> Tuple[List[IpcMemory], "torch.Tensor"]:
        """Allocate the IPC buffers backing the all-reduce workspace.

        Args:
            mapping: parallel mapping; ``tp_size`` scales every allocation.
            size: per-rank size of each data and Lamport buffer (multiplied by
                ``tp_size`` for the allocation).

        Returns:
            The seven ``IpcMemory`` objects (ping, pong, barriers in,
            barriers out, Lamport 0-2) and a CPU int64 tensor holding their
            serialized pointers in the same order, followed by two zeros.
        """
        import torch

        is_p2p_supported = can_access_peer(mapping)
        data_size = size * mapping.tp_size
        barrier_size = IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size * 2

        # The order of this list defines the pointer layout of the workspace.
        buffers = [
            IpcMemory(mapping, data_size, is_p2p_supported),  # ping
            IpcMemory(mapping, data_size, is_p2p_supported),  # pong
            IpcMemory(mapping, barrier_size, is_p2p_supported),  # barriers in
            IpcMemory(mapping, barrier_size, is_p2p_supported),  # barriers out
            IpcMemory(mapping, data_size, is_p2p_supported),  # lamport 0
            IpcMemory(mapping, data_size, is_p2p_supported),  # lamport 1
            IpcMemory(mapping, data_size, is_p2p_supported),  # lamport 2
        ]
        pointers = [ptr for buffer in buffers for ptr in buffer.serialize()]
        pointers += [0, 0]
        return buffers, torch.tensor(pointers,
                                     dtype=torch.int64,
                                     device="cpu")
# Module-level CustomAllReduceHelper singleton; None until
# init_all_reduce_helper() runs.
custom_all_reduce_helper = None


def init_all_reduce_helper():
    """Replace the module-level singleton with a fresh CustomAllReduceHelper."""
    global custom_all_reduce_helper
    custom_all_reduce_helper = CustomAllReduceHelper()


def current_all_reduce_helper():
    """Return the singleton created by ``init_all_reduce_helper``.

    Raises ``AssertionError`` (via ``assert``) if it has not been created yet.
    """
    helper = custom_all_reduce_helper
    assert helper is not None, "You must call `init_all_reduce_helper` first"
    return helper