Torch Compile & Piecewise CUDA Graph
In this guide, we show how to enable torch.compile and Piecewise CUDA Graph in TensorRT LLM. TensorRT LLM uses torch.compile for two purposes: lightweight vertical fusion and Piecewise CUDA Graph.
Piecewise CUDA Graph runs the components that CUDA Graph cannot capture (primarily attention) in eager mode, while capturing and replaying the remaining parts with CUDA Graph to reduce launch overhead in the context phase. We implement it on top of torch.compile because partitioning a model between CUDA Graph and eager execution, and managing the graphs in pure eager mode, is cumbersome.
Usage
To enable torch.compile and Piecewise CUDA Graph, add the following configuration to extra_config.yml. The file is typically passed to trtllm-serve or trtllm-bench with the launch argument --extra_llm_api_options extra_config.yml.
```yaml
# ... Other extra config
torch_compile_config:
  capture_num_tokens: '${capture_num_tokens}' # List of num tokens to capture, e.g., [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, ..., 3072]
  enable_userbuffers: false
  enable_piecewise_cuda_graph: true
```
Tips for Piecewise CUDA Graph
Piecewise CUDA Graph & Generation-Only CUDA Graph
Piecewise CUDA Graph handles only context-only and mixed context+generation iterations, while the generation-only CUDA Graph handles only pure generation iterations. The capture sizes for each type of CUDA Graph are therefore specified separately in the extra config. Use capture_num_tokens under torch_compile_config for Piecewise CUDA Graph, and max_batch_size under cuda_graph_config for the generation-only CUDA Graph. Currently, the default value of capture_num_tokens is [2**i for i in range(8)] + [i for i in range(256, 3073, 256)]. However, this setting should be tuned for the specific hardware, model, and parallel strategy. For tuning guidance, see the Performance Tuning section below.
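Expanding the default expression in plain Python shows exactly which token counts are captured. This is only a quick check of the expression quoted above, not TensorRT LLM code:

```python
# Default capture_num_tokens as quoted in this guide: powers of two up to 128,
# then every multiple of 256 from 256 to 3072.
default_capture_num_tokens = [2**i for i in range(8)] + [i for i in range(256, 3073, 256)]

print(default_capture_num_tokens)
print(len(default_capture_num_tokens))  # 20 captured sizes
```

By default, 20 graphs are captured. The gap between neighboring sizes grows from 1 token at the small end to 256 tokens at the large end.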
```yaml
cuda_graph_config:
  enable_padding: true
  max_batch_size: 1024 # Max capture batch size for the generation-only CUDA Graph. By default, TensorRT LLM generates a capture list based on it.
torch_compile_config:
  capture_num_tokens: '${capture_num_tokens}' # capture_num_tokens for Piecewise CUDA Graph
  enable_userbuffers: false
  enable_piecewise_cuda_graph: true
```
Piecewise CUDA Graph Padding
Padding means that, at runtime, the token count is padded up to the next larger captured token count. Unlike the generation-only CUDA Graph, padding is mandatory for Piecewise CUDA Graph. Context-phase token counts vary widely, so capturing a graph for every possible length is impractical.
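The padding rule can be sketched as a sorted-list lookup, assuming the captured sizes are kept in ascending order. The helper name padded_num_tokens is hypothetical. Returning None for token counts above the largest captured size is also an assumption made for illustration, since this guide does not say how such iterations are handled:

```python
import bisect
from typing import Optional, Sequence


def padded_num_tokens(num_tokens: int, captured: Sequence[int]) -> Optional[int]:
    """Return the smallest captured token count >= num_tokens.

    `captured` must be sorted in ascending order. Returns None when
    num_tokens exceeds every captured size (hypothetical: no graph to replay).
    """
    idx = bisect.bisect_left(captured, num_tokens)
    return captured[idx] if idx < len(captured) else None


captured = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
print(padded_num_tokens(300, captured))  # 512: 300 tokens are padded up to the 512-token graph
print(padded_num_tokens(128, captured))  # 128: exact match, no padding
print(padded_num_tokens(600, captured))  # None: larger than any captured size
```

In the first call, 212 of the 512 tokens are padding. That wasted compute is what capture-list tuning in the Performance Tuning section trades against memory.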
Performance Tuning
Piecewise CUDA Graph uses a token-count-based capture strategy. It captures a CUDA graph for each user-specified token count. At runtime, it replays the graph that matches the iteration's token count in a single forward pass, first padding the token count up to the next captured token count if needed.
Piecewise CUDA Graph primarily benefits host-bound iterations in the context phase. Iterations with larger token counts spend more time on GPU work and are therefore less exposed to host-side overhead, so the benefit shrinks as the token count grows. Capturing a broader set of token counts, however, increases GPU memory usage and can reduce achievable concurrency. We recommend manually tuning capture_num_tokens to balance latency, memory footprint, and concurrency for your workload.
Guidelines for capture_num_tokens:
Define bounds:
Lower bound: base it on typical context lengths. In low-latency workflows with KV-cache reuse, it can be fewer than 10 tokens.
Upper bound: set by the hardware and model configuration. Choose the largest token count that still shows a measurable benefit from Piecewise CUDA Graph after padding.
Choose a step size: balance coverage against memory overhead. Use denser steps for small token counts and a fixed step (e.g., 256) for larger ones.
Manage trade-offs: more capture points reduce padding but increase memory use and can lower maximum concurrency. Fewer points save memory but increase padding and wasted compute.
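One way to turn these guidelines into a concrete list is sketched below. build_capture_num_tokens is a hypothetical helper, not a TensorRT LLM API. It uses powers of two for the dense region below dense_limit and a fixed step above it. With lower=1 and upper=3072, it reproduces the default list quoted earlier. The resulting list would then be written into capture_num_tokens in extra_config.yml:

```python
from typing import List


def build_capture_num_tokens(lower: int, upper: int,
                             dense_limit: int = 256, step: int = 256) -> List[int]:
    """Hypothetical helper: dense power-of-two sizes below `dense_limit`,
    then a fixed `step` from `dense_limit` up to `upper`.

    `upper` itself is included only when (upper - dense_limit) is a multiple of `step`.
    """
    sizes = []
    n = 1
    # Dense region: powers of two in [lower, dense_limit) that do not exceed upper.
    while n < dense_limit and n <= upper:
        if n >= lower:
            sizes.append(n)
        n *= 2
    # Sparse region: a fixed step for large token counts.
    sizes.extend(range(dense_limit, upper + 1, step))
    return sorted(set(sizes))


print(build_capture_num_tokens(8, 1024))  # lower bound 8, upper bound 1024
```

Raising step or lowering upper shrinks the list and saves memory at the cost of more padding. Lowering dense_limit moves the coarse region closer to small token counts.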
Even with Piecewise CUDA Graph enabled, you may still observe bubbles in the context (prefill) phase. They occur primarily because the attention operator still runs in eager mode and has substantial host-side overhead.
Known Issues
torch.compile does not work with configurations that use multiple ModelEngines, such as:
Speculative Decoding in Two-Model Style
```yaml
speculative_config:
  decoding_type: "MTP"
  mtp_eagle_one_model: False # Not supported
```
```yaml
speculative_config:
  decoding_type: "Eagle"
  eagle3_one_model: False # Not supported
```
Multimodal Model Family
Development Guide
Background Knowledge
Currently, TensorRT LLM relies on torch.compile's fullgraph mode to enable Piecewise CUDA Graph. This means every operation in the model must be traceable by torch.compile without graph breaks.
Custom Op
Ops that cannot be expressed with native torch ops must be wrapped in a custom op to work properly with torch.compile. A custom op consists of two main parts: the op forward implementation and the fake kernel.
Op forward implementation: defines how the op computes its forward result, for example by launching a custom CUDA kernel.
Fake kernel: tells torch.compile the dtype and shape of the output tensors, so that it can trace the op without running it.
Once an op is wrapped as a torch custom op, its implementation is a black box to torch.compile, which relies entirely on the fake kernel during tracing.
Below is a simple example: a custom op wrapping FlashInfer's silu_and_mul, together with its fake kernel.
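```python
import torch

# silu_and_mul (FlashInfer) and ENABLE_PDL come from elsewhere in the TensorRT LLM source.
@torch.library.custom_op("trtllm::flashinfer_silu_and_mul", mutates_args=())
def flashinfer_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    return silu_and_mul(x, enable_pdl=ENABLE_PDL)

# Fake kernel: same dtype as x, with the last dim halved because the two halves of x are combined.
@flashinfer_silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x).chunk(2, dim=-1)[1].contiguous()
```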
For more examples, please refer to tensorrt_llm/_torch/custom_ops.
Current Status
For popular models such as DeepSeek, Qwen, and Llama, we have already wrapped some large modules into custom ops to avoid trace failures and graph breaks. We also exclude the output projection and MTP from torch.compile's scope.
As a result, developers working inside the attention custom op, the MoE routed-expert op, or the MTP module do not need to worry about complex torch.compile constraints. These parts are either black boxes to torch.compile or outside its scope. Developers only need to make sure that the fake kernels of the attention custom op and the routed-expert op stay aligned with the actual implementations.
Figure 1. The current model definition for DeepSeek
Reasons to wrap attention into a large custom op:
The C++ attention op interface is too complex: its argument count exceeds the limit of torch custom ops.
MLA slices its inputs to dispatch the MLA context and generation kernels. This introduces dynamic shapes, which may trigger recompilation during real inference.
It gives attention a clear boundary, so that Piecewise CUDA Graph can easily recognize it.
Attention uses some operators that cause graph breaks and are hard to avoid.
Reasons to wrap MoE into a large custom op:
MoE uses many DeepEP ops that are not wrapped as custom ops.
Chunked MoE is hard to support. It uses loops with data-dependent iteration counts, which force Dynamo to unroll extensively and significantly slow down compilation.
For ops outside attention and MLP, such as layernorm and allreduce, developers must follow torch.compile's constraints.
TensorRT LLM Custom Backend
Figure 2. TensorRT LLM Custom torch.compile Backend Overview
Figure 2 gives an overview of the TensorRT LLM custom backend for torch.compile.
Torch IR Optimization
Torch IR is the FX graph traced directly by TorchDynamo. It has two properties that make it convenient for graph rewriting and information extraction:
Operations are preserved as is: we can easily find a specific operation and transform it into arbitrary operations, without having to deal with auto_functionalized wrappers and the like.
Original tensor variable names are preserved in the FX graph: Piecewise CUDA Graph needs to find the SymInt that represents the number of tokens. We rely on the shape of input_ids to locate that SymInt correctly.
ATen IR Optimization
We obtain ATen IR by explicitly calling aot_module_simplified on the FX graph. ATen IR:
Is in SSA form (no input mutations).
Uses a strict subset of ATen ops (fewer than 250). In Torch IR, the Python native add operator, torch.Tensor.add(), and torch.ops.aten.add.Tensor can be three different ops. After the transformation, they are all the same op.
Carries guaranteed metadata, e.g., propagated dtype and shape.
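The normalization of op spellings can be observed with make_fx (torch.fx.experimental.proxy_tensor). It is used here as a stand-in for aot_module_simplified, since both trace down to ATen ops. This is an illustrative sketch, not part of the TensorRT LLM backend:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def three_adds(x: torch.Tensor, y: torch.Tensor):
    # Three spellings of the same addition at the Python level.
    a = x + y            # Python operator
    b = x.add(y)         # torch.Tensor method
    c = torch.add(x, y)  # torch function
    return a, b, c


gm = make_fx(three_adds)(torch.randn(4), torch.randn(4))
targets = [node.target for node in gm.graph.nodes if node.op == "call_function"]
print(targets)  # every spelling is traced to the same ATen overload, aten.add.Tensor
```

A pattern written against aten.add.Tensor therefore matches all three spellings. This is one reason fusion patterns are written at this level.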
At this IR level, TensorRT LLM performs the following optimizations.
Operation Fusion
All fusions are located in tensorrt_llm/_torch/compilation/patterns and implemented with torch.compile's pattern matcher. Unlike the official approach, which generates source patterns by tracing, we write source patterns directly in a lower-level IR. This avoids:
Inadequate handling of scalars and lists:
Scalars get specialized into the traced pattern, forcing one pattern per value—impractical and non-general.
Lists are flattened, turning elements into separate input arguments, making it impossible to match the original operation.
Trace-driven pitfalls: Because it’s trace-based, the generated source patterns may not meet our needs and can introduce additional issues as we expand pattern coverage.
Fusion mainly targets AllReduce and RMSNorm:
AllReduce related fusion: Fuse the following operations into one AllReduce op.
AllReduce + Residual + RMSNorm
AllReduce + Residual + RMSNorm + FP8 Quantization
AllReduce + Residual + RMSNorm + FP4 Quantization
AllReduce with User Buffer: Converts AllReduce operations to use userbuffers to avoid extra copy overhead.
We enable these fusions in torch.compile because they’re difficult to express in eager mode. For the AllReduce + RMSNorm fusion, which is cross-module, implementing it in eager mode would require moving code between modules, leading to redundant, complex, and hard-to-maintain logic.
For user buffers, torch.compile provides a global, flattened view of the model, making it easy for us to manage user buffers.
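For reference, the unfused computation that the first AllReduce pattern replaces can be written in eager PyTorch as below. This is a sketch under stated assumptions. The AllReduce is simulated by summing per-rank partial tensors on one device. The RMSNorm uses the common x * rsqrt(mean(x^2) + eps) * weight form. The fused op is assumed to return both the normalized output and the updated residual. None of these names are TensorRT LLM APIs:

```python
import torch


def allreduce_residual_rmsnorm_reference(partials, residual, weight, eps: float = 1e-6):
    """Unfused reference: AllReduce(SUM) -> residual add -> RMSNorm."""
    # Simulated AllReduce: each rank would end up with the sum of all partials.
    hidden = torch.stack(partials).sum(dim=0)
    # Residual add; this sum is also the residual stream for the next layer.
    hidden = hidden + residual
    # RMSNorm in float32, then cast back to the input dtype and scale.
    h32 = hidden.float()
    normed = h32 * torch.rsqrt(h32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return normed.to(hidden.dtype) * weight, hidden


torch.manual_seed(0)
num_tokens, hidden_size, world_size = 4, 8, 2
partials = [torch.randn(num_tokens, hidden_size) for _ in range(world_size)]
residual = torch.randn(num_tokens, hidden_size)
weight = torch.ones(hidden_size)
out, new_residual = allreduce_residual_rmsnorm_reference(partials, residual, weight)
```

In eager mode these are three separate steps, and the first and last usually live in different modules. The fusion pass collapses them into a single AllReduce op.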
Re-inplace Optimization
Because ATen IR is in SSA form, in-place operations are rewritten as out-of-place ones via a mutation wrapper (auto_functionalized or auto_functionalized_v2). The wrapper can introduce an extra tensor copy of mutated arguments. In a TorchInductor pipeline, later passes typically eliminate this copy, but TensorRT LLM relies on custom ops and does not use Inductor. To avoid the redundant overhead, we remove the wrapper ourselves and preserve the intended in-place update.
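The extra copy that this pass removes can be seen in a toy eager-mode comparison. This is an illustration only. functionalized_update mimics the effect of a functionalization wrapper: it clones the mutated argument, mutates the clone, and returns it. inplace_update is the form the pass restores. Both function names, like scale_, are hypothetical:

```python
import torch


def scale_(buf: torch.Tensor, factor: float) -> None:
    """A mutating op: updates `buf` in place."""
    buf.mul_(factor)


def functionalized_update(buf: torch.Tensor, factor: float) -> torch.Tensor:
    # SSA-style form: the mutated argument is cloned first, so the input is
    # never modified; this clone is the extra copy described above.
    out = buf.clone()
    scale_(out, factor)
    return out


def inplace_update(buf: torch.Tensor, factor: float) -> torch.Tensor:
    # Re-inplaced form: write directly into the original storage.
    scale_(buf, factor)
    return buf


x = torch.ones(1024)
y = functionalized_update(x, 2.0)
z = inplace_update(x, 2.0)
print(y.data_ptr() == x.data_ptr())  # False: a new buffer was allocated and filled
print(z.data_ptr() == x.data_ptr())  # True: no extra allocation or copy
```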
Auto Multi-stream
Currently, torch.compile does not create a subgraph for user-defined CUDA streams. Instead, it converts stream usage into set_stream calls. Because set_stream has no consumers, it is removed during the Torch IR to ATen IR transformation, and all multi-stream scheduling is lost.
To address this, we implemented an auto multi-stream scheduler:
Builds a DAG of the FX graph with explicit dependencies, including special handling for in-place ops
Computes a critical path using a rough cost model
Schedules nodes onto up to max_num_streams streams, as specified in the user config
Inserts multi-stream custom ops: the FX graph executes operators in list order, so we insert stream-control operators directly into the graph. These operators have no users, so we cannot run dead-code elimination after multi-stream scheduling.
Below is an example of multi-stream execution, in which trtllm.dsv3_router_gemm_op.default runs in parallel with trtllm.silu_and_mul.default + trtllm.fp4_quantize.default:
```text
call_function record_event trtllm.record_event (1,) {}
call_function fp4_quantize_2 trtllm.fp4_quantize.default (mm_1, arg18_1, 16) {}
call_function getitem_9 <built-in function getitem> (fp4_quantize_2, 0) {}
call_function getitem_10 <built-in function getitem> (fp4_quantize_2, 1) {}
call_function nvfp4_gemm_2 trtllm.nvfp4_gemm.default (getitem_9, arg19_1, getitem_10, arg20_1, arg21_1, torch.bfloat16) {}
call_function permute_2 aten.permute.default (arg17_1, [1, 0]) {}
call_function record_event_1 trtllm.record_event (0,) {}
call_function silu_and_mul_1 trtllm.silu_and_mul.default (nvfp4_gemm_2,) {}
call_function fp4_quantize_3 trtllm.fp4_quantize.default (silu_and_mul_1, arg22_1, 16) {}
call_function getitem_11 <built-in function getitem> (fp4_quantize_3, 0) {}
call_function record_event_2 trtllm.record_event (4,) {}
call_function getitem_12 <built-in function getitem> (fp4_quantize_3, 1) {}
call_function record_event_3 trtllm.record_event (3,) {}
call_function set_stream trtllm.set_stream (1,) {}
call_function wait_event trtllm.wait_event (0,) {}
call_function wait_event_1 trtllm.wait_event (1,) {}
call_function dsv3_router_gemm_op trtllm.dsv3_router_gemm_op.default (mm_1, permute_2, None, torch.float32) {}
call_function record_stream trtllm.record_stream (permute_2, 1) {}
call_function record_stream_1 trtllm.record_stream (mm_1, 1) {}
call_function record_event_4 trtllm.record_event (2,) {}
call_function set_stream_1 trtllm.set_stream (0,) {}
call_function wait_event_2 trtllm.wait_event (2,)
```
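The scheduling steps can be sketched in plain Python on a toy DAG. Everything here is hypothetical. The node names loosely follow the example graph, and the costs are made-up numbers standing in for the rough cost model. The placement policy is one simple choice, not necessarily the one TensorRT LLM uses: critical-path nodes stay on stream 0, and other nodes go round-robin to the remaining streams.

```python
from functools import lru_cache
from typing import Dict, List, Tuple


def critical_path(deps: Dict[str, List[str]], cost: Dict[str, float]) -> List[str]:
    """Most expensive dependency chain in a DAG.

    deps[n] lists the producers that node n depends on; every node is a key.
    """
    users: Dict[str, List[str]] = {n: [] for n in deps}
    for n, inputs in deps.items():
        for producer in inputs:
            users[producer].append(n)

    @lru_cache(maxsize=None)
    def longest_from(n: str) -> float:
        # Total cost of the most expensive chain starting at node n.
        return cost[n] + max((longest_from(u) for u in users[n]), default=0.0)

    # Start from the most expensive source node, then keep following the user
    # that continues the most expensive chain.
    node = max((n for n in deps if not deps[n]), key=longest_from)
    path = [node]
    while users[node]:
        node = max(users[node], key=longest_from)
        path.append(node)
    return path


def assign_streams(deps: Dict[str, List[str]], cost: Dict[str, float],
                   max_num_streams: int) -> Dict[str, int]:
    """Critical-path nodes stay on stream 0; other nodes go round-robin to 1..N-1."""
    on_path = set(critical_path(deps, cost))
    side_streams = list(range(1, max_num_streams)) or [0]
    assignment: Dict[str, int] = {}
    k = 0
    for n in deps:  # iterate in graph order so the round-robin is deterministic
        if n in on_path:
            assignment[n] = 0
        else:
            assignment[n] = side_streams[k % len(side_streams)]
            k += 1
    return assignment


def cross_stream_edges(deps: Dict[str, List[str]],
                       assignment: Dict[str, int]) -> List[Tuple[str, str]]:
    """Producer -> consumer edges that would need an event record/wait pair."""
    return [(p, n) for n, inputs in deps.items() for p in inputs
            if assignment[p] != assignment[n]]


# Toy version of the example graph; the costs are made-up numbers.
deps = {
    "mm_1": [],
    "fp4_quantize_2": ["mm_1"],
    "nvfp4_gemm_2": ["fp4_quantize_2"],
    "silu_and_mul_1": ["nvfp4_gemm_2"],
    "fp4_quantize_3": ["silu_and_mul_1"],
    "dsv3_router_gemm_op": ["mm_1"],
}
cost = {"mm_1": 5, "fp4_quantize_2": 1, "nvfp4_gemm_2": 8,
        "silu_and_mul_1": 1, "fp4_quantize_3": 1, "dsv3_router_gemm_op": 3}
streams = assign_streams(deps, cost, max_num_streams=2)
print(streams)
print(cross_stream_edges(deps, streams))
```

Here the router GEMM lands on stream 1. The single cross-stream edge, mm_1 to dsv3_router_gemm_op, is where a record/wait event pair would go. This roughly mirrors the event ops around dsv3_router_gemm_op in the graph dump above.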
Piecewise CUDA Graph
We implement Piecewise CUDA Graph execution on top of torch.compile: non-capturable regions run in eager mode, while the rest of the model is captured and replayed as CUDA Graph segments.
In the current design, we assume the attention block is the only non-capturable component. To maintain stable input pointers across segment boundaries, we convert attention to an in-place variant. Instead of allocating its own output, attention writes results into a tensor preallocated by the preceding CUDA Graph segment. This guarantees that each segment’s inputs are allocated by CUDA Graph and, therefore, stable for that segment’s capture.
Figure 3. Piecewise Runner
Notes:
Attention MUST NOT return any output; its output tensor must be allocated by CUDA Graph.
Each sub-CUDA Graph MUST have at least one input tensor whose shape contains the number of tokens.
Only the num_of_tokens dimension may have a dynamic shape.
Common Trace Failures
Custom op fake kernel: every custom op needs a correct fake kernel. Whenever a custom op changes, update its fake kernel accordingly.
Loops with a dynamic iteration count: technically not a trace failure, but they lead to tracing times that are generally unacceptable. When converting PyTorch modeling code to an FX graph, torch.compile unrolls loops. A loop with a large, dynamic iteration count and a large body therefore takes a long time to trace.
If the loop's inputs and outputs can easily be expressed as a custom op, replace the loop with a custom op.
If the iteration count does not change during the lifetime of the inference service (e.g., the loop over the model's decoder layers), it is fine to leave the loop as is.
Graph Breaks
Using unsupported operators
Python native operators: print, sys.intern(), etc.
pybind/nanobind operators
Solution: Wrap them in torch custom ops. For complex operators like attention that exceed the argument limit of PyTorch's custom-op interface, wrap them in a higher-level module to reduce the argument count.
Some torch operators:
torch.nonzero(): produces a tensor with a data-dependent dynamic shape
torch.sym_min: SymInt-aware min
torch.Tensor.tolist(), torch.Tensor.item()
Solution: Use them inside a custom op if they are not involved in producing the custom op's output tensor.
Using a custom object's method: for a class like the Mapping config, methods such as has_pp() cannot be called directly in the model forward.
Solution: Convert the result to a bool in the model's __init__ and use the bool in forward.
```python
class Mapping(object):
    def __init__(self, ...):
        ...

    def has_pp(self):
        # Cannot use this method in torch.compile
        return self.pp_size > 1
```
Data-dependent control (DDC) flow in the code
Solution: Avoid DDC in the code by pre-computing the result outside torch.compile's scope. In the following example, compute torch.sum(data) at the data preparation stage and pass the result to forward.
```python
import torch


class TestCase(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, data):
        y = x ** 2
        if torch.sum(data) >= 4:  # Data Dependent Control Here!
            t = y
        else:
            t = y / 2
        t = t + 10
        return t


test_case = TestCase()
# Backend() is the TensorRT LLM torch.compile backend (its import is omitted here).
test_case = torch.compile(test_case, backend=Backend())
x = torch.randn(5).cuda()
data = torch.ones(2, dtype=torch.int32)
data[0] = 2
data[1] = 2
test_case(x, data)
```
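Following that solution, the example can be rewritten so that the data-dependent decision is made during data preparation. This is a sketch. TestCaseNoDDC is a hypothetical name, and backend="eager" replaces the TensorRT LLM Backend() so the snippet runs on CPU. The flag is passed as a Python bool, which torch.compile treats as a constant and guards on, so at most one graph is compiled per flag value.

```python
import torch


class TestCaseNoDDC(torch.nn.Module):
    def forward(self, x: torch.Tensor, use_square: bool) -> torch.Tensor:
        y = x ** 2
        # Branch on a Python bool computed outside the compiled region.
        t = y if use_square else y / 2
        return t + 10


x = torch.randn(5)
data = torch.ones(2, dtype=torch.int32)
data[0] = 2
data[1] = 2
# Data preparation stage: resolve the data-dependent condition eagerly.
use_square = bool(torch.sum(data) >= 4)

test_case = torch.compile(TestCaseNoDDC(), backend="eager")
out = test_case(x, use_square)
```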
Recompilation
Avoid data-dependent dynamic shapes in the model forward (e.g., slicing a tensor based on an input value). They introduce 0/1 specialization and can trigger recompilation.
0/1 specialization: torch.compile recompiles the model if a dynamic dimension of a tensor equals 0 or 1. In the worst case, a single dimension yields three compilations: 0, 1, and >= 2.
For an int argument whose value changes at runtime, use SymInt rather than int in the C++ custom op definition. Otherwise, every change in its value triggers a recompile.
```cpp
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def("allgather(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
    m.def("allgather_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
}
```
Some recompilations are hard to notice:
Python native min(list) and max(list): a recompile is triggered whenever the list elements change.
Control flow based on dynamic shapes
Next power of two: we previously implemented the next-power-of-2 function with bit_length(), which caused a recompile for every int value. The code has been rewritten to be torch.compile-friendly:
```python
def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1

    # Following code is equivalent to 1 << (x - 1).bit_length()
    # But this impl does not contain bit_length(), so it can be used by torch compile.
    # It can correctly handle 64-bit numbers, which should be enough for now.
    n = x - 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    return n + 1
```