Source code for tensorrt_edgellm.onnx_export.lora

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import time
from collections import namedtuple
from typing import Tuple

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import torch
from safetensors import safe_open
from safetensors.torch import save_file

GEMMInfo = namedtuple("GEMMInfo", ["input", "output", "name", "weight_shape"])


def _find_matmul_node(quantize_linear_node: gs.Node) -> gs.Node:
    """
    Find the MatMul node after the quantize linear node. Usually it is 2-3 levels deep.
    """
    node = quantize_linear_node
    max_depth = 5
    depth = 0
    while node.op != "MatMul" and depth < max_depth:
        node = node.outputs[0].outputs[0]
        depth += 1
    if node.op != "MatMul":
        raise ValueError(
            f"MatMul node not found within {max_depth} levels after quantize node {quantize_linear_node.name}. Please check the ONNX graph."
        )
    return node


def _find_weight_shape(gemm_node: gs.Node) -> gs.Constant:
    """
    Find the weight shape of the GEMM node. The weight shape is not intuitive because of the quantization and transpose nodes.
    """
    # Weights are always the second input of a GEMM node.
    node = gemm_node.inputs[1].inputs[0]
    max_depth = 5
    depth = 0
    num_transpose = 0
    while node.op != "DequantizeLinear" and node.op != "TRT_MXFP8DequantizeLinear" and depth < max_depth:
        if node.op == "Transpose":
            num_transpose += 1
        node = node.inputs[0].inputs[0]
        depth += 1
    if node.op != "DequantizeLinear" and node.op != "TRT_MXFP8DequantizeLinear":
        raise ValueError(
            f"DequantizeLinear node not found within {max_depth} levels above GEMM node {gemm_node.name}. Please check the ONNX graph."
        )
    weight = node.inputs[0]
    if num_transpose % 2 == 1:
        weight_shape = (weight.shape[1], weight.shape[0])
    else:
        weight_shape = weight.shape
    return weight_shape


def _match_fp8_gemm(graph: gs.Graph):
    """
    Match FP8 GEMM nodes in the graph.
    """
    fp8_gemm_infos = []
    fp8_quantize_linear_nodes = [
        node for node in graph.nodes if node.op == "TRT_FP8QuantizeLinear"
    ]
    for node in fp8_quantize_linear_nodes:
        # Skip the optional Cast in front of the quantize node and use the tensor feeding it.
        if node.inputs[0].inputs[0].op == "Cast":
            input_node = node.inputs[0].inputs[0].inputs[0]
        else:
            input_node = node.inputs[0]
        matmul_node = _find_matmul_node(node)
        weight_shape = _find_weight_shape(matmul_node)
        gemm_info = GEMMInfo(input=input_node,
                             output=matmul_node.outputs[0],
                             name=matmul_node.name,
                             weight_shape=weight_shape)
        fp8_gemm_infos.append(gemm_info)
    return fp8_gemm_infos


def _match_nvfp4_gemm(graph: gs.Graph):
    """
    Match NVFP4 GEMM nodes in the graph.
    """
    nvfp4_gemm_infos = []
    nvfp4_quantize_linear_nodes = [
        node for node in graph.nodes if node.op == "TRT_FP4DynamicQuantize"
    ]
    for node in nvfp4_quantize_linear_nodes:
        input_node = node.inputs[0]
        matmul_node = _find_matmul_node(node)
        weight_shape = _find_weight_shape(matmul_node)
        gemm_info = GEMMInfo(input=input_node,
                             output=matmul_node.outputs[0],
                             name=matmul_node.name,
                             weight_shape=weight_shape)
        nvfp4_gemm_infos.append(gemm_info)
    return nvfp4_gemm_infos


def _match_int4_gemm(graph: gs.Graph):
    """
    Match INT4 GEMM nodes in the graph.
    """
    int4_gemm_infos = []
    int4_gemm_nodes = [
        node for node in graph.nodes if node.op == "Int4GroupwiseGemmPlugin"
    ]
    for node in int4_gemm_nodes:
        # For AWQ, the input is smoothed by a Mul node and then cast before quantization.
        producer = node.inputs[0].inputs[0]
        if producer.op == "Cast" and "input_quantizer" in producer.inputs[0].name:
            cast_node = producer
            mul_node = cast_node.inputs[0].inputs[0]
            input_node = mul_node.inputs[0]
        # For GPTQ, no smoothing is applied.
        else:
            input_node = node.inputs[0]
        weight_shape = (node.attrs["gemm_k"], node.attrs["gemm_n"])
        gemm_info = GEMMInfo(input=input_node,
                             output=node.outputs[0],
                             name=node.name,
                             weight_shape=weight_shape)
        int4_gemm_infos.append(gemm_info)
    return int4_gemm_infos


def _match_mxfp8_gemm(graph: gs.Graph):
    """
    Match MXFP8 GEMM nodes in the graph.
    """
    mxfp8_gemm_infos = []
    mxfp8_quantize_linear_nodes = [
        node for node in graph.nodes if node.op == "TRT_MXFP8DynamicQuantize"
    ]
    for node in mxfp8_quantize_linear_nodes:
        input_node = node.inputs[0]
        matmul_node = _find_matmul_node(node)
        weight_shape = _find_weight_shape(matmul_node)
        gemm_info = GEMMInfo(input=input_node,
                             output=matmul_node.outputs[0],
                             name=matmul_node.name,
                             weight_shape=weight_shape)
        mxfp8_gemm_infos.append(gemm_info)
    return mxfp8_gemm_infos


def _match_fp16_gemm(graph: gs.Graph):
    """
    Match FP16 GEMM nodes in the graph.
    """
    fp16_gemm_infos = []
    fp16_gemm_nodes = [node for node in graph.nodes if node.op == "MatMul"]
    for node in fp16_gemm_nodes:
        # Only MatMuls whose second input is a constant weight are treated as GEMMs.
        if not isinstance(node.inputs[1], gs.Constant):
            continue
        input_node = node.inputs[0]
        weight_shape = node.inputs[1].shape
        gemm_info = GEMMInfo(input=input_node,
                             output=node.outputs[0],
                             name=node.name,
                             weight_shape=weight_shape)
        fp16_gemm_infos.append(gemm_info)
    return fp16_gemm_infos


def _match_gemm_infos(graph: gs.Graph):
    """
    Match all GEMM nodes in the graph.
    """
    gemm_infos = []
    gemm_infos.extend(_match_fp8_gemm(graph))
    gemm_infos.extend(_match_nvfp4_gemm(graph))
    gemm_infos.extend(_match_int4_gemm(graph))
    gemm_infos.extend(_match_mxfp8_gemm(graph))
    gemm_infos.extend(_match_fp16_gemm(graph))
    return gemm_infos
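

# A minimal usage sketch (not part of the module) showing how the matched GEMM
# infos could be inspected; the model path below is hypothetical.
#
#   onnx_model = onnx.load("/path/to/model.onnx", load_external_data=False)
#   graph = gs.import_onnx(onnx_model)
#   for info in _match_gemm_infos(graph):
#       k, n = info.weight_shape
#       print(f"{info.name}: K={k}, N={n}")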


# Helper functions for LoRA weight processing
def _load_adapter_config(config_path: str) -> Tuple[float, int]:
    """
    Load adapter config and return lora_alpha and r values.
    
    Args:
        config_path (str): Path to adapter_config.json
        
    Returns:
        Tuple[float, int]: (lora_alpha, r)
    """
    with open(config_path, 'r') as f:
        config = json.load(f)
    return config['lora_alpha'], config['r']
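

# For reference, a PEFT-style adapter_config.json typically contains the two
# fields read above (values here are illustrative only), e.g.:
#
#   {"lora_alpha": 16, "r": 8, "target_modules": ["q_proj", "v_proj"], ...}
#
# for which _load_adapter_config would return (16, 8).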


def _process_tensor_name(key: str) -> str:
    """
    Process tensor name by removing 'base_model.model' prefix and ensuring it starts with 'model'.
    
    Args:
        key (str): Original tensor name
        
    Returns:
        str: Processed tensor name
    """
    if key.startswith('base_model.model.'):
        key = key[len('base_model.model.'):]
    if not key.startswith('model.'):
        key = 'model.' + key
    return key
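

# Illustrative mappings (representative PEFT tensor names, not taken from a
# specific checkpoint):
#   'base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight'
#       -> 'model.layers.0.self_attn.q_proj.lora_A.weight'
#   'layers.0.mlp.up_proj.lora_B.weight'
#       -> 'model.layers.0.mlp.up_proj.lora_B.weight'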


def _should_keep_tensor(key: str) -> bool:
    """
    Check if tensor should be kept (exclude norm and lm_head tensors).
    
    Args:
        key (str): Tensor name
        
    Returns:
        bool: True if tensor should be kept
    """
    return 'norm' not in key and 'lm_head' not in key


def _process_tensor(tensor: torch.Tensor, key: str, lora_alpha: float,
                    r: int) -> torch.Tensor:
    """
    Process tensor according to requirements:
    1. Convert bf16 to fp16
    2. Multiply lora_B.weight by lora_alpha/r
    3. Ensure correct shapes for lora_A and lora_B
    
    Args:
        tensor (torch.Tensor): Input tensor
        key (str): Tensor name
        lora_alpha (float): LoRA alpha value
        r (int): LoRA rank
        
    Returns:
        torch.Tensor: Processed tensor
    """

    # Handle lora_B.weight multiplication
    if 'lora_B.weight' in key:
        tensor = tensor * (lora_alpha / r)

    # Ensure correct shapes
    if 'lora_A.weight' in key:
        if tensor.shape[-1] != r:
            tensor = tensor.transpose(-2, -1)
    elif 'lora_B.weight' in key:
        if tensor.shape[0] != r:
            tensor = tensor.transpose(-2, -1)

    # Convert to fp16
    tensor = tensor.to(torch.float16).contiguous()

    return tensor
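

# A small worked example with hypothetical shapes, assuming lora_alpha=16 and
# r=8 (so lora_B is scaled by 16 / 8 = 2.0):
#   'lora_A.weight' of shape [8, 4096] -> last dim != r, transposed to
#       [4096, 8], then cast to fp16
#   'lora_B.weight' of shape [4096, 8] -> scaled by 2.0, first dim != r,
#       transposed to [8, 4096], then cast to fp16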


# Main functions for external use
def insert_lora_and_save(onnx_dir: str):
    """
    Insert dynamic LoRA patterns into an ONNX model and save the result as
    lora_model.onnx in the same directory.

    Args:
        onnx_dir (str): Directory containing the ONNX model (model.onnx and config.json)
    """
    start_time = time.time()

    # Load ONNX model
    onnx_model_path = os.path.join(onnx_dir, "model.onnx")
    print(f"Loading original ONNX model from {onnx_model_path}...")
    # The LoRA model will share the same data as the base model
    onnx_model = onnx.load(onnx_model_path, load_external_data=False)
    graph = gs.import_onnx(onnx_model)

    # Insert dynamic LoRA patterns
    print("Inserting dynamic LoRA patterns...")

    # Track all GEMM nodes that need LoRA
    gemm_infos = _match_gemm_infos(graph)

    # Insert LoRA patterns for each GEMM
    for gemm_info in gemm_infos:
        input_tensor = gemm_info.input
        output_tensor = gemm_info.output
        gemm_name = gemm_info.name
        weight_shape = gemm_info.weight_shape
        k, n = weight_shape
        if "lm_head" in gemm_name:
            continue

        # Create dynamic input tensors for LoRA weights
        gemm_name_for_lora = gemm_name.replace("/", ".").rsplit(".", 1)[0][1:]
        lora_a = gs.Variable(f"{gemm_name_for_lora}.lora_A.weight",
                             dtype=np.float16,
                             shape=[k, f"{gemm_name_for_lora}.rank"])
        lora_b = gs.Variable(f"{gemm_name_for_lora}.lora_B.weight",
                             dtype=np.float16,
                             shape=[f"{gemm_name_for_lora}.rank", n])
        graph.inputs.extend([lora_a, lora_b])

        # First MatMul: input @ lora_A
        lora_mid = gs.Variable(f"{gemm_name}/lora_mid", dtype=np.float16)
        graph.layer(name=f"{gemm_name}/lora_matmul_A",
                    op="MatMul",
                    inputs=[input_tensor, lora_a],
                    outputs=[lora_mid])

        # Second MatMul: (input @ lora_A) @ lora_B
        lora_out = gs.Variable(f"{gemm_name}/lora_gemm_out", dtype=np.float16)
        graph.layer(name=f"{gemm_name}/lora_matmul_B",
                    op="MatMul",
                    inputs=[lora_mid, lora_b],
                    outputs=[lora_out])

        # Add LoRA output to original output
        final_output = gs.Variable(f"{gemm_name}/lora_add_output",
                                   dtype=np.float16)
        final_output.outputs = output_tensor.outputs.copy()
        graph.layer(name=f"{gemm_name}/lora_add",
                    op="Add",
                    inputs=[output_tensor, lora_out],
                    outputs=[final_output])

        # Update the output connections
        for out_node in final_output.outputs:
            if final_output not in out_node.inputs:
                out_node.inputs.append(final_output)
            if output_tensor in out_node.inputs:
                out_node.inputs.remove(output_tensor)

    graph.cleanup().toposort().fold_constants().cleanup()

    # Save modified ONNX model
    output_model_path = os.path.join(onnx_dir, "lora_model.onnx")
    print(f"Saving modified ONNX model to {output_model_path}...")
    modified_onnx_model = gs.export_onnx(graph)
    onnx.save_model(modified_onnx_model, output_model_path)
    end_time = time.time()
    print(f"LoRA model saved to {output_model_path}")
    print(f"LoRA insertion completed in {end_time - start_time:.2f}s")
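

# Example invocation (hypothetical path). The directory is expected to contain
# model.onnx; lora_model.onnx is written back into the same directory:
#
#   insert_lora_and_save("/path/to/exported_onnx_model")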


def process_lora_weights_and_save(input_dir: str, output_dir: str):
    """
    Process LoRA adapter weights (rename tensors, scale lora_B by lora_alpha/r,
    fix shapes, convert to fp16) and save them together with the adapter config.

    Args:
        input_dir (str): Directory containing input adapter files
        output_dir (str): Directory where processed files will be saved
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Load adapter config
    config_path = os.path.join(input_dir, 'adapter_config.json')
    lora_alpha, r = _load_adapter_config(config_path)

    # Copy config file to output directory
    shutil.copy2(config_path, os.path.join(output_dir, 'config.json'))

    # Load safetensors
    safetensor_path = os.path.join(input_dir, 'adapter_model.safetensors')
    processed_tensors = {}

    try:
        with safe_open(safetensor_path, framework="pt") as f:
            for key in f.keys():
                # Skip unwanted tensors
                if not _should_keep_tensor(key):
                    continue

                # Process tensor name
                new_key = _process_tensor_name(key)

                # Load and process tensor
                tensor = f.get_tensor(key)
                processed_tensor = _process_tensor(tensor, key, lora_alpha, r)

                # Store processed tensor
                processed_tensors[new_key] = processed_tensor

                # Print tensor info
                print(f"\nTensor: {new_key}")
                print(f"Shape: {processed_tensor.shape}")
                print(f"Dtype: {processed_tensor.dtype}")
                print("-" * 50)

        # Save processed tensors
        output_path = os.path.join(output_dir,
                                   'processed_adapter_model.safetensors')
        save_file(processed_tensors, output_path)

        print(f"\nProcessed tensors saved to: {output_path}")
        print(
            f"Config file copied to: {os.path.join(output_dir, 'config.json')}")

    except Exception as e:
        print(f"Error processing safetensor file: {e}")
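

# Example invocation (hypothetical paths). input_dir is expected to contain the
# adapter_config.json and adapter_model.safetensors produced by PEFT training:
#
#   process_lora_weights_and_save("/path/to/lora_adapter",
#                                 "/path/to/processed_lora")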