Advanced Runtime Features#
Overview#
The TensorRT Edge-LLM C++ Runtime provides several advanced inference features, ranging from CUDA graph optimization to dynamic LoRA adapter switching. These features are designed to maximize performance, flexibility, and efficiency in production deployments.
CUDA Graph Optimization#
The runtime provides sophisticated CUDA graph capture and execution for the generation phase (standard mode only):
Graph Capture Process:
Pre-execution: TensorRT engine is executed once before graph capture to avoid errors
State Simulation: KV-Cache state is simulated to match post-prefill conditions
Input Validation: Tensor shapes and configurations are validated before capture
Graph Creation: CUDA stream capture records the entire generation step execution
Hash-based Storage: Graphs are stored with hash keys based on input shapes and LoRA configurations
Graph Execution:
Hash Lookup: Input configurations are hashed to find matching pre-captured graphs
Direct Launch: Matching graphs are launched directly via cudaGraphLaunch
Fallback Execution: Non-matching configurations fall back to standard TensorRT execution
Multi-configuration Support: Separate graphs captured for different batch sizes and LoRA adapters
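The hash lookup with fallback described above can be sketched in plain C++. This is a minimal illustration, not the runtime's implementation: `GraphKey`, `GraphKeyHash`, `GraphCache`, and `GraphHandle` are hypothetical names, the key fields are an assumption about what "input shapes and LoRA configurations" covers, and an integer stands in for `cudaGraphExec_t` so the sketch runs without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical key: which fields the runtime actually hashes is an assumption.
struct GraphKey {
    int32_t batchSize;
    std::string loraName;  // empty string means no adapter is active

    bool operator==(const GraphKey& other) const {
        return batchSize == other.batchSize && loraName == other.loraName;
    }
};

struct GraphKeyHash {
    std::size_t operator()(const GraphKey& key) const {
        std::size_t seed = std::hash<int32_t>{}(key.batchSize);
        // Boost-style hash_combine to mix in the adapter name.
        seed ^= std::hash<std::string>{}(key.loraName) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};

// Stand-in for cudaGraphExec_t so the sketch compiles without CUDA.
using GraphHandle = int;

class GraphCache {
public:
    void store(const GraphKey& key, GraphHandle graph) {
        assert(key.batchSize > 0);
        graphs_[key] = graph;
    }

    // Returns the captured graph for this configuration, or std::nullopt when
    // the caller should fall back to standard TensorRT execution.
    std::optional<GraphHandle> find(const GraphKey& key) const {
        auto it = graphs_.find(key);
        if (it == graphs_.end()) {
            return std::nullopt;
        }
        return it->second;
    }

    std::size_t size() const { return graphs_.size(); }

private:
    std::unordered_map<GraphKey, GraphHandle, GraphKeyHash> graphs_;
};
```

In a real integration the hit path would launch the stored graph and the miss path would run the engine normally. Keeping one entry per batch size and adapter is also what drives the per-graph memory overhead.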
Performance Benefits:
Reduced Kernel Launch Overhead: CUDA graphs can reduce kernel launch overhead by 10-30%
Consistent Latency: Graph execution provides more predictable per-token latency
Optimized Memory Access: Graph replay optimizes GPU memory access patterns
Limitations:
Standard Runtime Only: CUDA graphs are not supported in EAGLE SpecDecode mode
Configuration-Specific: Separate graphs required for different batch sizes and LoRA configurations
Memory Overhead: Each captured graph requires additional GPU memory
LoRA (Low-Rank Adaptation) Support#
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that adapts large language models by learning low-rank decomposition matrices rather than updating all model weights. Instead of modifying the original model parameters, LoRA adds small trainable rank decomposition matrices to existing layers, enabling task-specific customization with minimal memory overhead and computational cost.
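For reference, the general LoRA formulation (from the technique itself, not specific to this runtime) can be written as below, where W_0 is the frozen base weight, B and A are the trainable low-rank factors of rank r, and α is a scaling hyperparameter:

```latex
h = W_0 x + \frac{\alpha}{r}\, B A x,
\qquad B \in \mathbb{R}^{d \times r},\quad A \in \mathbb{R}^{r \times k},\quad r \ll \min(d, k)
```

Each adapted layer therefore adds r(d + k) parameters instead of the d·k required by a full weight update.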
The runtime provides comprehensive dynamic LoRA adapter management:
Adapter Management:
SafeTensors Loading: LoRA weights loaded from industry-standard SafeTensors format
Dynamic Registry: Multiple adapters managed with name-based identification
Rank Validation: Adapter ranks validated against engine’s maximum supported rank
GPU Memory Storage: Efficient GPU memory allocation for adapter weights
Runtime Switching:
Zero-Rank Fallback: Adapters disabled by setting rank dimensions to 0
Tensor Binding: LoRA tensors bound to both prefill and generation execution contexts
CUDA Graph Integration: Separate graph capture for each LoRA configuration
State Preservation: KV-cache and model state maintained during adapter switches
API Interface:
addLoraWeights(name, filePath): Loads LoRA weights from SafeTensors files
switchLoraWeights(name): Switches active adapter or disables with empty name
getAvailableLoraWeights(): Returns list of loaded adapters
getActiveLoraWeightsName(): Returns currently active adapter name
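The bookkeeping behind name-based adapter management, rank validation, and zero-rank fallback can be illustrated with a small stand-alone sketch. `LoraRegistry` and its methods are hypothetical and only mirror the behavior described above. The sketch tracks ranks rather than real weights, and it does not reproduce the runtime's actual classes or signatures.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical registry: tracks adapter ranks only, not tensors.
class LoraRegistry {
public:
    explicit LoraRegistry(int32_t maxRank) : maxRank_(maxRank) {
        assert(maxRank > 0);
    }

    // Register an adapter, rejecting ranks the engine was not built for.
    void add(const std::string& name, int32_t rank) {
        if (name.empty()) {
            throw std::invalid_argument("adapter name must be non-empty");
        }
        if (rank <= 0 || rank > maxRank_) {
            throw std::invalid_argument("adapter rank exceeds engine maximum");
        }
        ranks_[name] = rank;
    }

    // An empty name disables LoRA (zero-rank fallback).
    void switchTo(const std::string& name) {
        if (!name.empty() && ranks_.count(name) == 0) {
            throw std::out_of_range("unknown adapter: " + name);
        }
        active_ = name;
    }

    // Rank dimension to bind for the active adapter; 0 means base model only.
    int32_t activeRank() const {
        return active_.empty() ? 0 : ranks_.at(active_);
    }

    const std::string& activeName() const { return active_; }

    std::vector<std::string> available() const {
        std::vector<std::string> names;
        for (const auto& entry : ranks_) {
            names.push_back(entry.first);
        }
        return names;
    }

private:
    int32_t maxRank_;
    std::map<std::string, int32_t> ranks_;
    std::string active_;
};
```

Validating the rank at load time rather than at switch time keeps switching cheap, since a switch then only changes which tensors and rank dimensions are bound.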
Use Cases:
Domain Adaptation: Switch between medical, legal, technical, or other specialized domains
Multi-tenant Serving: Serve different customized models to different users/customers
A/B Testing: Compare performance of different fine-tuned variants
Task-Specific Optimization: Use specialized adapters for different task types
Batch Processing#
The runtime supports efficient batch processing throughout the inference pipeline:
Memory Management:
Unified Allocation: Batch-aware memory allocation reserves space for maximum batch size
Tensor Layouts: Consistent tensor layouts support concurrent sequence processing
Dynamic Padding: Input sequences padded to batch maximum length for parallel processing
Execution Flow:
Parallel Prefill: All batch sequences processed simultaneously during prefill
Concurrent Generation: Tokens generated for all active sequences in each iteration
Individual Tracking: Each sequence maintains independent completion state
Dynamic Removal: Completed sequences removed from batch as they finish
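Dynamic padding and dynamic removal can be sketched with standard containers. The helper names (`padBatch`, `SequenceState`, `removeFinished`) are hypothetical, and right-padding with a caller-supplied pad token is an assumption; the source does not state which side the runtime pads on.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Pad every sequence to the longest length in the batch (right padding assumed).
std::vector<std::vector<int32_t>> padBatch(
    const std::vector<std::vector<int32_t>>& sequences, int32_t padId) {
    std::size_t maxLen = 0;
    for (const auto& seq : sequences) {
        maxLen = std::max(maxLen, seq.size());
    }
    std::vector<std::vector<int32_t>> padded;
    padded.reserve(sequences.size());
    for (const auto& seq : sequences) {
        std::vector<int32_t> row(seq);
        row.resize(maxLen, padId);  // appends padId until the row reaches maxLen
        assert(row.size() == maxLen);
        padded.push_back(std::move(row));
    }
    return padded;
}

// Hypothetical per-sequence completion state.
struct SequenceState {
    int32_t requestId;
    bool finished;
};

// Drop finished sequences while preserving the order of the remaining ones.
void removeFinished(std::vector<SequenceState>& batch) {
    batch.erase(std::remove_if(batch.begin(), batch.end(),
                               [](const SequenceState& s) { return s.finished; }),
                batch.end());
}
```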
Performance Benefits:
Increased Throughput: Process multiple requests simultaneously
GPU Utilization: Better GPU utilization through parallel processing
Amortized Overhead: Fixed costs amortized across multiple sequences
Considerations:
Memory Usage: Batch size limited by available GPU memory
Latency Trade-off: Higher batch sizes may increase per-request latency
System Prompt KV-Cache Optimization#
The runtime implements intelligent caching for repeated system prompts:
Cache Management:
Hash-based Storage: System prompts cached using combined hash of prompt and LoRA adapter
KV-Cache Persistence: Key-value cache content saved and reused for identical system prompts
Memory Efficiency: Avoids recomputing prefill for repeated system prompt patterns
Automatic Reuse: Cache automatically detected and loaded for matching prompts
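A combined prompt-and-adapter key can be sketched as follows. `promptCacheKey`, `CachedPrefix`, and `SystemPromptCache` are hypothetical names, and a plain float vector stands in for the saved KV-cache contents. How the runtime actually hashes and stores its entries is not specified here. The sketch also keeps the original strings in each entry so that a hash collision cannot return the wrong cache.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Combine adapter name and prompt into one hash; the '\0' separator keeps
// ("ab", "c") and ("a", "bc") from producing the same joined string.
std::size_t promptCacheKey(const std::string& systemPrompt, const std::string& loraName) {
    std::string joined = loraName;
    joined.push_back('\0');
    joined += systemPrompt;
    return std::hash<std::string>{}(joined);
}

// Hypothetical cache entry; kvData is a stand-in for saved KV-cache content.
struct CachedPrefix {
    std::string systemPrompt;
    std::string loraName;
    std::vector<float> kvData;
};

class SystemPromptCache {
public:
    void store(const std::string& systemPrompt, const std::string& loraName,
               std::vector<float> kvData) {
        assert(!systemPrompt.empty());
        entries_[promptCacheKey(systemPrompt, loraName)] =
            CachedPrefix{systemPrompt, loraName, std::move(kvData)};
    }

    // Returns the cached prefix, or nullptr when prefill must be recomputed.
    const CachedPrefix* lookup(const std::string& systemPrompt,
                               const std::string& loraName) const {
        auto it = entries_.find(promptCacheKey(systemPrompt, loraName));
        if (it == entries_.end()) {
            return nullptr;
        }
        const CachedPrefix& entry = it->second;
        if (entry.systemPrompt != systemPrompt || entry.loraName != loraName) {
            return nullptr;  // hash collision: treat as a miss
        }
        return &entry;
    }

private:
    std::unordered_map<std::size_t, CachedPrefix> entries_;
};
```

Including the adapter name in the key matters because the same prompt produces different key-value states under different LoRA weights.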
Performance Benefits:
Reduced Latency: Eliminates prefill computation for cached system prompts
Memory Optimization: Efficient storage of frequently used prompt states
Batch Compatibility: Cache reuse works seamlessly with batch processing
Use Cases:
Chatbots with Fixed Instructions: Cache common system instructions
API Services: Reuse system prompts across multiple user requests
Multi-turn Conversations: Cache conversation context across turns
Vocabulary Reduction#
Vocabulary reduction optimizes model size and inference performance by selecting a subset of the most relevant tokens for domain-specific deployments.
Overview:
Token Selection: Users create a vocabulary mapping using tensorrt-edgellm-reduce-vocab (see tensorrt_edgellm/vocab_reduction/vocab_reduction.py for implementation)
Automatic Runtime Support: Runtime transparently uses vocab_map.safetensors when present in the engine directory
Performance Gains: Smaller LM head layers, faster inference, reduced memory footprint
Methods: input_aware (recommended, analyzes usage patterns) or frequency (token frequency-based)
Note: Vocabulary reduction is task-dependent, so the provided methods are reference implementations only. Users should build the token map with methods and sample data appropriate to their task to ensure proper coverage of expected tokens.
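Conceptually, a reduced vocabulary means the LM head scores only the kept tokens, and the winning index is translated back to an original token ID. The sketch below assumes the map is a flat array from reduced index to original ID and uses greedy selection; the actual layout of vocab_map.safetensors and the runtime's sampling path are not described here, and `selectOriginalToken` is a hypothetical helper.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

// Greedy pick over the reduced logits, then map back to the original token ID.
// vocabMap[i] is assumed to hold the original ID of reduced token i.
int32_t selectOriginalToken(const std::vector<float>& reducedLogits,
                            const std::vector<int32_t>& vocabMap) {
    assert(!reducedLogits.empty());
    assert(reducedLogits.size() == vocabMap.size());
    auto best = std::max_element(reducedLogits.begin(), reducedLogits.end());
    auto reducedId = static_cast<std::size_t>(std::distance(reducedLogits.begin(), best));
    return vocabMap[reducedId];
}
```

A token missing from the map can never be generated, which is why coverage of the expected output tokens matters when building it.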
Multimodal Processing#
The runtime provides comprehensive support for Vision Language Models (VLMs):
Vision Processing Pipeline:
Image Preprocessing: Normalization, resizing, and tensor conversion for vision inputs
ViT Execution: Vision Transformer models process images to generate embeddings
Token Integration: Vision embeddings integrated with text tokens before LLM processing
Dynamic Resolution: Support for variable image resolutions and patch counts
Supported Architectures:
Qwen-VL Series: Qwen2-VL, Qwen2.5-VL, and Qwen3-VL with dynamic image patches and window attention
InternVL Series: InternVL3 with 0.5 downsampling ratio and fixed image size processing
Phi-4-Multimodal: LoRA based vision-language model support
Rotary Position Encoding: Advanced positional encoding for multimodal sequences
Processing Flow:
Image Loading: Images loaded from file paths specified in request
Vision Encoding: Images processed through ViT to generate vision embeddings
Token Merging: Vision embeddings merged with text token embeddings
LLM Processing: Combined embeddings processed through language model
Text Generation: Output text generated based on multimodal context
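The token merging step can be illustrated with a placeholder-substitution sketch: each image placeholder position in the token sequence takes the next vision embedding, and all other positions keep their text embedding. The placeholder scheme, the `mergeEmbeddings` helper, and the `Embedding` alias are assumptions for illustration; the source only states that vision embeddings are merged with text token embeddings.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Embedding = std::vector<float>;

// Replace each placeholder position with the next vision embedding, in order.
std::vector<Embedding> mergeEmbeddings(const std::vector<int32_t>& tokenIds,
                                       const std::vector<Embedding>& textEmbeddings,
                                       const std::vector<Embedding>& visionEmbeddings,
                                       int32_t imagePlaceholderId) {
    assert(tokenIds.size() == textEmbeddings.size());
    std::vector<Embedding> merged;
    merged.reserve(tokenIds.size());
    std::size_t nextVision = 0;
    for (std::size_t i = 0; i < tokenIds.size(); ++i) {
        if (tokenIds[i] == imagePlaceholderId) {
            if (nextVision >= visionEmbeddings.size()) {
                throw std::invalid_argument("more placeholders than vision embeddings");
            }
            merged.push_back(visionEmbeddings[nextVision++]);
        } else {
            merged.push_back(textEmbeddings[i]);
        }
    }
    if (nextVision != visionEmbeddings.size()) {
        throw std::invalid_argument("unused vision embeddings");
    }
    return merged;
}
```

Because every vision embedding occupies a sequence position, larger images or more patches lengthen the prefill and enlarge the KV-cache.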
Performance Characteristics:
Prefill Impact: Vision processing adds to prefill phase latency
Memory Usage: Vision embeddings increase KV-cache memory requirements
Batch Processing: Multiple images can be processed in batch mode
Usage Examples#
Using CUDA Graphs#
// CUDA graphs are automatically enabled for standard runtime
LLMInferenceRuntime runtime(engineDir);
// First inference captures the graph
auto response1 = runtime.handleRequest(request);
// Subsequent inferences with same configuration use captured graph
auto response2 = runtime.handleRequest(request);
Dynamic LoRA Switching#
// Load multiple LoRA adapters
runtime.addLoraWeights("medical", "lora_weights/medical_adapter.safetensors");
runtime.addLoraWeights("legal", "lora_weights/legal_adapter.safetensors");
// Switch between adapters dynamically
runtime.switchLoraWeights("medical");
auto medical_response = runtime.handleRequest(medical_request);
runtime.switchLoraWeights("legal");
auto legal_response = runtime.handleRequest(legal_request);
// Disable LoRA to use base model
runtime.switchLoraWeights("");
auto base_response = runtime.handleRequest(base_request);
Batch Processing#
// Prepare multiple requests
std::vector<InferenceRequest> batch_requests = {
{.inputText = "Question 1", .maxLength = 100},
{.inputText = "Question 2", .maxLength = 100},
{.inputText = "Question 3", .maxLength = 100}
};
// Process batch (implementation depends on your integration)
for (const auto& request : batch_requests) {
auto response = runtime.handleRequest(request);
// Process response...
}
System Prompt Caching#
// First request with system prompt - triggers cache creation
InferenceRequest request1;
request1.systemPrompt = "You are a helpful medical assistant.";
request1.inputText = "What is aspirin?";
auto response1 = runtime.handleRequest(request1);
// Second request with same system prompt - reuses cache
InferenceRequest request2;
request2.systemPrompt = "You are a helpful medical assistant.";
request2.inputText = "What is ibuprofen?";
auto response2 = runtime.handleRequest(request2); // Faster prefill!
Multimodal VLM Inference#
// Initialize runtime with visual engine
LLMInferenceRuntime runtime(engineDir, visualEngineDir);
// Single image inference
InferenceRequest request1;
request1.inputText = "Describe this image in detail.";
request1.imagePaths = {"photo.jpg"};
auto response1 = runtime.handleRequest(request1);
// Multiple images
InferenceRequest request2;
request2.inputText = "Compare these two images.";
request2.imagePaths = {"image1.jpg", "image2.jpg"};
auto response2 = runtime.handleRequest(request2);
Vocabulary Reduction#
# Step 1: Reduce vocabulary using input-aware analysis
tensorrt-edgellm-reduce-vocab \
--model_dir Qwen/Qwen3-4B-Instruct-2507 \
--output_dir reduced_vocab \
--reduced_vocab_size 16384 \
--method input_aware \
--max_samples 100000
# Optional: Add --d2t_path for EAGLE speculative decoding models
# --d2t_path onnx_models/qwen3-4b_eagle_draft/d2t.safetensors
# Step 2: Export model with reduced vocabulary
tensorrt-edgellm-export-llm \
--model_dir quantized/qwen3-4b \
--output_dir llm_onnx \
--reduced_vocab_dir reduced_vocab/
# Step 3: Build TensorRT engine (same as standard workflow)
./build/examples/llm/llm_build \
--onnxDir llm_onnx \
--engineDir engines/qwen3-4b \
--maxBatchSize 1
# Step 4: Run inference (same as standard workflow)
# Runtime automatically uses vocab_map.safetensors when present
./build/examples/llm/llm_inference \
--engineDir engines/qwen3-4b \
--inputFile input.json \
--outputFile output.json
Output: The reduced_vocab/ directory will contain vocab_map.safetensors with the vocabulary mapping, which the runtime automatically applies during inference.
Next Steps#
Try Examples: Run the Examples to see advanced features in action
Benchmark Performance: Measure the impact of CUDA graphs, LoRA, and batch processing
Integrate into Application: Use advanced features to optimize your deployment
Review API Documentation: Refer to detailed API docs in cpp/runtime/ headers
Additional Resources#
Runtime API: Refer to the cpp/runtime/ directory
Example Applications: Refer to examples/llm/ and examples/multimodal/
Architecture Overview: Refer to Overview
LoRA Support: Refer to LoRA documentation in Python Export Pipeline