Guide to quantizing a customized Hugging Face model for TensorRT-LLM deployment
ModelOpt can usually quantize PyTorch models from Hugging Face directly. By default, ModelOpt searches the PyTorch model and replaces each torch nn.Linear module with a quantized linear module. If the model does not use nn.Linear for its linear layers, a customized Hugging Face plugin needs to be implemented to convert the model to use nn.Linear instead.
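For a model that already uses nn.Linear, no plugin is needed; below is a minimal sketch of the default flow (the toy model, the calibration loop, and the INT8 config choice are only illustrative):

import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

# A toy model whose linear layers are plain nn.Linear modules.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

def forward_loop(model):
    # Calibration: run representative inputs through the model.
    for _ in range(8):
        model(torch.randn(4, 16))

model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
print(model)  # the nn.Linear modules now appear as quantized linear modules with quantizers attached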
The following is an example of how a customized Hugging Face model can be supported using ModelOpt:
The DBRX model is an MoE model with a customized MoE linear implementation. The MoE layer in DBRX is implemented as a DbrxExperts module, where the three linear ops (w1, v1 and w2) are represented as nn.Parameter tensors and each linear op is forwarded as a pure matmul op.
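For context, the upstream DbrxExpertGLU stores the fused expert weights roughly as in the condensed sketch below (an illustration of the pattern, not the exact transformers code), which is why ModelOpt's default nn.Linear replacement does not apply:

import torch
import torch.nn as nn

class DbrxExpertGLU(nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size, moe_num_experts, activation_fn):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.activation_fn = activation_fn
        # All experts' weights live in fused nn.Parameter tensors; there is no nn.Linear to detect.
        self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # Per-expert weights are sliced out of the fused tensors and applied as plain matmuls.
        w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        gate = x.matmul(w1.t())
        up = x.matmul(v1.t())
        return (self.activation_fn(gate) * up).matmul(w2)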
As ModelOpt cannot detect these linear ops out of the box, a Hugging Face plugin is implemented as follows:
1. Define a customized _QuantDbrxExpertGLU as a DynamicModule with the same forward signature.
2. Rewrite the linear ops (w1, v1 and w2) as standard nn.Linear ops, and re-implement the forward method.
3. Register the new dynamic _QuantDbrxExperts to replace the DbrxExperts from modeling_dbrx.py in the transformers library.

Once the plugin is implemented, try quantizing the DBRX model; feel free to follow the llm_ptq example.
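With the plugin in place, quantizing DBRX follows the same pattern as the llm_ptq example. A minimal sketch (the checkpoint name, calibration prompts, and FP8 config are placeholders):

import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "databricks/dbrx-instruct", torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")

def forward_loop(model):
    # Replace with a real calibration dataset; a couple of prompts are shown only for brevity.
    for prompt in ["Hello, my name is", "The capital of France is"]:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        model(**inputs)

# The registered plugin converts the DBRX MoE experts to nn.Linear modules,
# so they are quantized together with the rest of the model's linear layers.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)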
TensorRT-LLM is open source. If this customized model is not yet supported by TensorRT-LLM, please modify export_tensorrt_llm_checkpoint or export_hf_checkpoint to export the quantized model for deployment with a customized TensorRT-LLM modeling implementation. Feel free to contact us if further support is needed.
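The export call typically looks like the sketch below (the decoder type string and export directories are assumptions; check the export API for the exact arguments in your ModelOpt version):

from modelopt.torch.export import export_hf_checkpoint, export_tensorrt_llm_checkpoint

# Hugging Face style checkpoint of the quantized model.
export_hf_checkpoint(model, export_dir="dbrx-quantized-hf")

# Or a TensorRT-LLM checkpoint; decoder_type="dbrx" is an assumption here.
export_tensorrt_llm_checkpoint(model, decoder_type="dbrx", export_dir="dbrx-quantized-trtllm")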
The following code snippet is excerpted from modelopt/torch/quantization/plugins/huggingface.py:
import torch
import torch.nn as nn
import transformers

from modelopt.torch.opt.dynamic import DynamicModule
from modelopt.torch.quantization.nn import QuantModuleRegistry
if hasattr(transformers.models, "dbrx"):
    # For more information on DbrxExpert, see https://github.com/huggingface/transformers/blame/dcdda5324bcc7a750b5e40e11dd795442204ff27/src/transformers/models/dbrx/modeling_dbrx.py#L756
    class _QuantDbrxExperts(DynamicModule):
        def _setup(self):
            """Modify the DbrxExpert."""
            # No setup is needed for DbrxExpert, we only need to update DbrxExpertGLU
            pass

        # forward method copied from the original dbrx repo - https://github.com/databricks/dbrx/blob/a3200393e678387a6f30f3e903108c650625eb21/model/modeling_dbrx.py#L795
        def forward(
            self,
            x: torch.Tensor,
            weights: torch.Tensor,
            top_weights: torch.Tensor,
            top_experts: torch.LongTensor,
        ) -> torch.Tensor:
            bsz, q_len, hidden_size = x.shape
            x = x.view(-1, hidden_size)
            out = torch.zeros_like(x)

            expert_mask = nn.functional.one_hot(
                top_experts, num_classes=self.moe_num_experts
            ).permute(2, 1, 0)
            for expert_idx in range(0, self.moe_num_experts):
                topk_idx, token_idx = torch.where(expert_mask[expert_idx])
                if token_idx.shape[0] == 0:
                    continue

                token_list = token_idx.tolist()
                topk_list = topk_idx.tolist()

                expert_tokens = x[None, token_list].reshape(-1, hidden_size)
                expert_out = (
                    self.mlp(expert_tokens, expert_idx) * top_weights[token_list, topk_list, None]
                )

                out.index_add_(0, token_idx, expert_out)

            out = out.reshape(bsz, q_len, hidden_size)
            return out

    class _QuantDbrxExpertGLU(DynamicModule):
        def _setup(self):
            """Modify the DbrxExpertGLU by using nn.Linear layers."""
            dtype, device = self.w1.dtype, self.w1.device

            def _copy_weights(modules, weights):
                modules.to(dtype=dtype, device=device)
                for expert_idx, module in enumerate(modules):
                    with torch.no_grad():
                        module.weight.copy_(weights[expert_idx].detach())

            self.w1_linear = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.ffn_hidden_size, bias=False)
                    for _ in range(self.moe_num_experts)
                ]
            )
            _copy_weights(
                self.w1_linear,
                self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size),
            )
            delattr(self, "w1")

            self.v1_linear = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.ffn_hidden_size, bias=False)
                    for _ in range(self.moe_num_experts)
                ]
            )
            _copy_weights(
                self.v1_linear,
                self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size),
            )
            delattr(self, "v1")

            self.w2_linear = nn.ModuleList(
                [
                    nn.Linear(self.ffn_hidden_size, self.hidden_size, bias=False)
                    for _ in range(self.moe_num_experts)
                ]
            )
            _copy_weights(
                self.w2_linear,
                self.w2.view(
                    self.moe_num_experts, self.ffn_hidden_size, self.hidden_size
                ).transpose(1, 2),
            )
            delattr(self, "w2")

        def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
            x1 = self.w1_linear[expert_idx](x)
            x2 = self.v1_linear[expert_idx](x)
            x1 = self.activation_fn(x1)
            x1 = x1 * x2
            return self.w2_linear[expert_idx](x1)

    if transformers.models.dbrx.modeling_dbrx.DbrxExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {transformers.models.dbrx.modeling_dbrx.DbrxExperts: "hf.DbrxExperts"}
        )(_QuantDbrxExperts)

    if transformers.models.dbrx.modeling_dbrx.DbrxExpertGLU not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {transformers.models.dbrx.modeling_dbrx.DbrxExpertGLU: "hf.DbrxExpertGLU"}
        )(_QuantDbrxExpertGLU)


def register_dbrx_moe_on_the_fly(model):
    """Register DBRX MoE modules as QUANT_MODULE.

    The MoE class in DBRX is `transformers_modules.modeling_dbrx.DbrxExpertGLU`, which loads dynamically.
    """
    if type(model).__name__ in ["DbrxForCausalLM"]:
        moe_type = type(model.transformer.blocks[0].ffn.experts.mlp)
        # Create a QuantDbrxExpertGLU class on the fly
        if QuantModuleRegistry.get(moe_type) is None:
            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantDbrxExpertGLU)
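When the DBRX checkpoint is loaded with trust_remote_code=True, the expert classes come from transformers_modules rather than transformers.models.dbrx, so the on-the-fly registration above should be applied to the loaded model before quantization. A minimal usage sketch, reusing the forward_loop and config from the earlier example:

from transformers import AutoModelForCausalLM
import modelopt.torch.quantization as mtq

model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-instruct", trust_remote_code=True)
register_dbrx_moe_on_the_fly(model)  # registers the dynamically loaded DbrxExpertGLU class
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)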