Guide to quantizing a customized model from Hugging Face for TensorRT-LLM deployment

ModelOpt can usually quantize PyTorch models from Hugging Face directly. By default, ModelOpt searches the PyTorch model and replaces each torch nn.Linear module with a quantized linear module.
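
For a model whose linear layers are standard nn.Linear modules, the default flow typically looks like the minimal sketch below. The model name, calibration prompts, and the mtq.FP8_DEFAULT_CFG choice are purely illustrative assumptions, not part of the plugin code in this guide.

import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model whose linear layers are plain nn.Linear modules.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def forward_loop(model):
    # Run a few calibration batches so the inserted quantizers can collect activation statistics.
    for prompt in ["Hello, my name is", "The capital of France is"]:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        model(**inputs)

# ModelOpt walks the model, swaps nn.Linear for quantized linear modules, then calibrates.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)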

If the model does not use nn.Linear for its linear layers, a customized Hugging Face plugin needs to be implemented to convert the model to use nn.Linear instead.

The following is an example of how a customized Hugging Face model can be supported using ModelOpt:

The DBRX model is an MoE model with a customized MoE linear implementation. The MoE layer in DBRX is implemented as a DbrxExperts module whose expert MLP (DbrxExpertGLU) stores its three linear ops (w1, v1 and w2) as nn.Parameter tensors. Each linear op is performed in the forward pass as a plain matmul.
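
For illustration, a simplified sketch of this pattern is shown below (not the exact transformers source): the expert weights live in flat nn.Parameter tensors and each projection is a manual matmul, so there is no nn.Linear module for ModelOpt to find and replace.

import torch
import torch.nn as nn

class SimplifiedDbrxExpertGLU(nn.Module):
    """Illustrative sketch of the original DbrxExpertGLU pattern, not the real source."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int):
        super().__init__()
        # All experts' weights are stored as flat nn.Parameter tensors ...
        self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.activation_fn = nn.SiLU()

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # ... and each expert's projection is a plain matmul against a slice of the
        # parameter, so ModelOpt sees no nn.Linear module to swap out.
        w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        x1 = self.activation_fn(x.matmul(w1.t()))
        x2 = x.matmul(v1.t())
        return (x1 * x2).matmul(w2)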

Since ModelOpt cannot detect these linear ops out of the box, a Hugging Face plugin is implemented as follows:

  1. Define a customized _QuantDbrxExpertGLU as a DynamicModule with the same forward signature.

  2. Rewrite the linear ops (w1, v1 and w2) as standard nn.Linear modules, and re-implement the forward method.

  3. Register the new dynamic _QuantDbrxExperts and _QuantDbrxExpertGLU to replace DbrxExperts and DbrxExpertGLU from modeling_dbrx.py in the transformers library.

  4. Try quantizing the DBRX model after the plugin is implemented; feel free to follow the llm_ptq example (a usage sketch also follows the code snippet below).

  5. TensorRT-LLM is open source. If this customized model is not yet supported by TensorRT-LLM, please modify export_tensorrt_llm_checkpoint or export_hf_checkpoint to export the quantized model for deployment with a customized TensorRT-LLM modeling implementation. Feel free to contact us if further support is needed.

The following code snippet is excerpted from modelopt/torch/quantization/plugins/huggingface.py

import torch
import torch.nn as nn
import transformers

from modelopt.torch.opt.dynamic import DynamicModule
from modelopt.torch.quantization.nn import QuantModuleRegistry

if hasattr(transformers.models, "dbrx"):
    # For more information on DbrxExperts, see https://github.com/huggingface/transformers/blame/dcdda5324bcc7a750b5e40e11dd795442204ff27/src/transformers/models/dbrx/modeling_dbrx.py#L756
    class _QuantDbrxExperts(DynamicModule):
        def _setup(self):
            """Modify the DbrxExpert."""
            # No setup is needed for DbrxExpert, we only need to update DbrxExpertGLU
            pass

        # forward method copied from the original dbrx repo - https://github.com/databricks/dbrx/blob/a3200393e678387a6f30f3e903108c650625eb21/model/modeling_dbrx.py#L795
        def forward(
            self,
            x: torch.Tensor,
            weights: torch.Tensor,
            top_weights: torch.Tensor,
            top_experts: torch.LongTensor,
        ) -> torch.Tensor:
            bsz, q_len, hidden_size = x.shape
            x = x.view(-1, hidden_size)
            out = torch.zeros_like(x)

            expert_mask = nn.functional.one_hot(
                top_experts, num_classes=self.moe_num_experts
            ).permute(2, 1, 0)
            for expert_idx in range(0, self.moe_num_experts):
                topk_idx, token_idx = torch.where(expert_mask[expert_idx])
                if token_idx.shape[0] == 0:
                    continue

                token_list = token_idx.tolist()
                topk_list = topk_idx.tolist()

                expert_tokens = x[None, token_list].reshape(-1, hidden_size)
                expert_out = (
                    self.mlp(expert_tokens, expert_idx) * top_weights[token_list, topk_list, None]
                )

                out.index_add_(0, token_idx, expert_out)

            out = out.reshape(bsz, q_len, hidden_size)
            return out

    class _QuantDbrxExpertGLU(DynamicModule):
        def _setup(self):
            """Modify the DbrxExpertGLU by using nn.Linear layers."""
            dtype, device = self.w1.dtype, self.w1.device

            def _copy_weights(modules, weights):
                modules.to(dtype=dtype, device=device)
                for expert_idx, module in enumerate(modules):
                    with torch.no_grad():
                        module.weight.copy_(weights[expert_idx].detach())

            self.w1_linear = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.ffn_hidden_size, bias=False)
                    for _ in range(self.moe_num_experts)
                ]
            )
            _copy_weights(
                self.w1_linear,
                self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size),
            )
            delattr(self, "w1")

            self.v1_linear = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.ffn_hidden_size, bias=False)
                    for _ in range(self.moe_num_experts)
                ]
            )
            _copy_weights(
                self.v1_linear,
                self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size),
            )
            delattr(self, "v1")

            self.w2_linear = nn.ModuleList(
                [
                    nn.Linear(self.ffn_hidden_size, self.hidden_size, bias=False)
                    for _ in range(self.moe_num_experts)
                ]
            )
            _copy_weights(
                self.w2_linear,
                self.w2.view(
                    self.moe_num_experts, self.ffn_hidden_size, self.hidden_size
                ).transpose(1, 2),
            )
            delattr(self, "w2")

        def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
            x1 = self.w1_linear[expert_idx](x)
            x2 = self.v1_linear[expert_idx](x)
            x1 = self.activation_fn(x1)
            x1 = x1 * x2
            return self.w2_linear[expert_idx](x1)

    if transformers.models.dbrx.modeling_dbrx.DbrxExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {transformers.models.dbrx.modeling_dbrx.DbrxExperts: "hf.DbrxExperts"}
        )(_QuantDbrxExperts)

    if transformers.models.dbrx.modeling_dbrx.DbrxExpertGLU not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {transformers.models.dbrx.modeling_dbrx.DbrxExpertGLU: "hf.DbrxExpertGLU"}
        )(_QuantDbrxExpertGLU)


def register_dbrx_moe_on_the_fly(model):
    """Register DBRX MoE modules as QUANT_MODULE.

    The MoE class in DBRX is `transformers_modules.modeling_dbrx.DbrxExpertGLU`, which loads dynamically.
    """
    if type(model).__name__ in ["DbrxForCausalLM"]:
        moe_type = type(model.transformer.blocks[0].ffn.experts.mlp)
        # Create a QuantDbrxExpertGLU class on the fly
        if QuantModuleRegistry.get(moe_type) is None:
            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantDbrxExpertGLU)
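

With the plugin in place, steps 4 and 5 above might look like the following sketch. The model name, calibration prompts, mtq.FP8_DEFAULT_CFG, the decoder_type="dbrx" value, and the import path for register_dbrx_moe_on_the_fly (derived from the file path above) are illustrative assumptions; the llm_ptq example contains the full, supported recipe.

import torch
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_tensorrt_llm_checkpoint
from modelopt.torch.quantization.plugins.huggingface import register_dbrx_moe_on_the_fly
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "databricks/dbrx-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# When DBRX is loaded through trust_remote_code, its MoE class lives under
# transformers_modules, so register it on the fly before quantization (step 4).
register_dbrx_moe_on_the_fly(model)

def forward_loop(model):
    # Illustrative calibration prompts; a real run would use a proper calibration dataset.
    for prompt in ["Hello, my name is", "The capital of France is"]:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        model(**inputs)

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# Export a TensorRT-LLM checkpoint for deployment (step 5); decoder_type="dbrx" is assumed here.
export_tensorrt_llm_checkpoint(
    model, decoder_type="dbrx", dtype=torch.bfloat16, export_dir="dbrx_fp8_ckpt"
)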