CPU Offloading

Note

CPU Offloading in Transformer Engine is currently available only for PyTorch. It supports all PyTorch modules, not just TE layers.

CPU offloading moves activation tensors from GPU to CPU memory during the forward pass and reloads them during backward. Transfers are asynchronous, enabling significant GPU memory savings with minimal overhead.

Unlike activation checkpointing, offloading avoids recomputation — activations are stored on CPU instead of being recalculated, making it faster when CPU-GPU bandwidth is sufficient.

Hardware Support

CPU offloading benefits greatly from fast CPU-GPU interconnects. The faster the link, the more effectively transfer time can be hidden behind computation.

[Diagram: a traditional PCIe system (CPU RAM to GPU HBM over PCIe at 128 GB/s) next to a GB200 Superchip, where Grace CPU RAM connects to two Blackwell GPU HBMs over NVLink-C2C at 900 GB/s per link.]

Figure 1. Traditional PCIe system vs GB200 Superchip with NVLink-C2C.

Traditional PCIe Gen5 x16 systems offer 128 GB/s bidirectional bandwidth between CPU and GPU, which limits offloading benefits.

With NVLink-C2C (GB200), bandwidth jumps to 900 GB/s bidirectional per link, making offloading increasingly attractive on modern NVIDIA superchips. The GB200 pairs a Grace CPU with 480 GB LPDDR5X memory and two Blackwell GPUs, each with 192 GB HBM3e (384 GB total), providing ample CPU memory for offloading activations.

Offloading/reloading consumes HBM bandwidth, which may compete with other GPU operations — even when transfers are asynchronous. This is unlikely to affect compute-bound operations like GEMMs, but the impact on memory-bound operations like quantization may be noticeable.
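A back-of-envelope estimate shows why link bandwidth matters for hiding transfers behind compute. The activation size and per-layer forward time below are hypothetical placeholders; measure your own workload before drawing conclusions:

```python
# Back-of-envelope check: can one layer's activation transfer hide behind
# its compute? All numbers here are illustrative, not measurements.

def transfer_time_ms(activation_bytes: float, bandwidth_gb_s: float) -> float:
    """Time to move one layer's activations at the given link bandwidth."""
    return activation_bytes / (bandwidth_gb_s * 1e9) * 1e3

# Hypothetical: 2 GB of activations per transformer layer.
act_bytes = 2e9
pcie_ms = transfer_time_ms(act_bytes, 128)    # PCIe Gen5 x16
nvlink_ms = transfer_time_ms(act_bytes, 900)  # NVLink-C2C

print(f"PCIe:   {pcie_ms:.1f} ms")    # ~15.6 ms
print(f"NVLink: {nvlink_ms:.1f} ms")  # ~2.2 ms

# The transfer is fully hidden if it finishes within the layer's forward time.
layer_forward_ms = 5.0  # hypothetical per-layer compute time
print("hidden on PCIe:  ", pcie_ms <= layer_forward_ms)    # False
print("hidden on NVLink:", nvlink_ms <= layer_forward_ms)  # True
```

With these assumed numbers, offloading stalls the PCIe system but is free on NVLink-C2C, which is the pattern the rest of this section formalizes.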

CPU Offloading in Transformer Engine

Transformer Engine supports CPU offloading of activations for sequential models. A model is considered sequential if it satisfies the following conditions:

  1. The model is a sequence of layers: x₁ = Layer₁(x₀), x₂ = Layer₂(x₁), …, xₙ = Layerₙ(xₙ₋₁). The layers may be any PyTorch modules, not just TE layers.

  2. Each intermediate tensor xᵢ is used only as input to the next layer (not elsewhere in the model).

  3. xᵢ is only needed as input to Layerᵢ₊₁’s backward pass and can be freed once that pass completes.

Most LLM architectures (stacked Transformer blocks) satisfy these conditions.

x₀ → Layer 1 → x₁ → Layer 2 → x₂ → Layer 3 → ··· → Layer N → xₙ

Figure 2. Sequential model: xᵢ₊₁ = Layerᵢ₊₁(xᵢ). Each layer consumes only the output of the previous one.

The example below shows how to offload activations for a sequence of torch.nn.Linear layers using the default scheduling algorithm:

import torch
from transformer_engine.pytorch import get_cpu_offload_context

# Setup
num_layers = 12
offloaded_layers = 3
layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)]
x = torch.randn(16, 1024, 1024, device="cuda")

# Get offloading context and sync function
cpu_offload_context, sync_function = get_cpu_offload_context(
    enabled=True,
    model_layers=num_layers,
    num_layers=offloaded_layers,
)

# Forward pass
for i in range(num_layers):
    # Context manager captures tensors saved for backward.
    # These tensors will be offloaded to CPU asynchronously.
    with cpu_offload_context:
        x = layers[i](x)

    # sync_function must be called after each layer's forward pass.
    # This cannot be done inside the context manager because
    # it needs the output tensor after the layer has finished.
    x = sync_function(x)

loss = x.sum()
loss.backward()

Let’s take a look at the API in detail:

def get_cpu_offload_context(
    enabled: bool = False,
    num_layers: Optional[int] = 1,
    model_layers: int = 1,
    manual_synchronization: bool = False,
    offload_stream: Optional[torch.cuda.Stream] = None,
    # ... (legacy parameters omitted, see :func:`get_cpu_offload_context`)
) -> Union[Tuple[ContextManager, Callable], Tuple[ContextManager, Callable, ManualOffloadSynchronizer]]:
    ...

The model_layers parameter must always be set to the total number of layers in the model. There are two modes of operation:

  1. Default scheduling — set num_layers to the number of layers to offload. The algorithm automatically schedules offload/reload operations to overlap with computation.

  2. Manual synchronization — set manual_synchronization=True (num_layers is ignored in this mode). This mode provides explicit control over when to start offload/reload using the returned ManualOffloadSynchronizer.

The transformer_engine.pytorch.get_cpu_offload_context() function returns:

  • context manager — wraps each layer’s forward pass to intercept tensors saved for backward.

  • sync function — registers a backward hook on the output tensor to trigger activation reload.

  • ManualOffloadSynchronizer (only in manual mode) — provides explicit control over offload/reload.

The usage pattern for default scheduling is:

cpu_offload_context, sync_function = get_cpu_offload_context(...)

for layer in layers:
    with cpu_offload_context:
        x = layer(x)
    x = sync_function(x)

Default Offloading Scheduling

Default scheduling is enabled when manual_synchronization=False (the default). The num_layers parameter must be specified to set the number of layers to offload. The algorithm then automatically determines when to offload and reload activations to maximize overlap with computation.

When num_layers of the model's model_layers layers are offloaded:

  • The first num_layers layers of the model are offloaded to CPU.

  • Offloading starts as soon as tensors are saved for backward — it does not wait for the layer’s forward pass to complete.

  • At most (model_layers - num_layers) sets of activations are on GPU at any time; both compute and reload may be stalled to enforce this limit.

  • Reloading must complete by the time the tensor is needed for the layer’s backward pass.

  • num_layers must be at most model_layers - 1 (setting it to model_layers raises an assertion error). However, model_layers - 1 leaves only 1 set of activations on GPU at a time, so compute and transfers cannot overlap, and a warning is raised. For full overlap, use model_layers - 2 or fewer.
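The constraints above can be restated as a small helper. This is a hypothetical illustration, not part of the Transformer Engine API:

```python
# Illustrative helper (not part of Transformer Engine): mirrors the default
# scheduler's constraint that at most model_layers - num_layers activation
# sets are resident on GPU at any time.

def gpu_activation_budget(model_layers: int, num_layers: int) -> int:
    """Return the number of activation sets the scheduler keeps on GPU."""
    if not (0 <= num_layers <= model_layers - 1):
        raise ValueError("num_layers must be at most model_layers - 1")
    budget = model_layers - num_layers
    if budget < 2:
        # Matches the documented warning: with only one resident set,
        # compute and transfers cannot overlap.
        print("warning: no compute/transfer overlap possible")
    return budget

print(gpu_activation_budget(5, 2))  # 3 sets resident on GPU
print(gpu_activation_budget(5, 3))  # 2 sets resident on GPU
```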

Specifying a low enough num_layers enables full overlap of computation and offload/reload. The following two scenarios illustrate this — one with full overlap, and one with stalls.

[Timeline diagram: during forward, Layer 1 and 2 offloads on the offload stream overlap Layers 1 to 5 on the compute stream; during backward, Layer 2 and 1 reloads on the reload stream overlap Layers 5 to 1, with no waits on either stream.]

Figure 3. With num_layers=2 and model_layers=5, at most 3 sets of activations are on GPU. Layer 1 offloading starts during its forward pass (when the first tensor is saved for backward). Offloading fully overlaps with the forward pass; reloading fully overlaps with the backward pass.

When num_layers is too high, the GPU memory limit forces stalls:

[Timeline diagram: as in Figure 3 but with num_layers=3; the compute stream waits before the Layer 4 and Layer 5 forwards, and before the Layer 3, 2, and 1 backwards, while offloads and reloads catch up.]

Figure 4. With num_layers=3 and model_layers=5, at most 2 sets of activations can be on GPU (5 - 3 = 2), which causes stalls. In forward, Layer 4 cannot start until Layer 2 is offloaded, otherwise there would be 3 sets of activations on GPU (Layers 2, 3, 4). In backward, Layer 3 cannot start immediately because its activations are still on CPU and must be reloaded first. Some tensors may finish reloading earlier, allowing parts of the layer (e.g., a sublayer) to run while the rest waits. The same applies to Layers 2 and 1.
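The forward-pass stalls can be sketched with a toy timeline simulation. This is illustrative only: it assumes uniform per-layer compute and offload times, and it is not Transformer Engine's scheduler:

```python
def simulate_forward(model_layers, offloaded, fwd_ms=1.0, off_ms=2.0):
    """Toy timeline of the forward pass under the default scheduler's
    memory limit. Returns the list of layers whose forward had to stall.
    Assumes every forward takes fwd_ms and every offload takes off_ms."""
    budget = model_layers - offloaded  # max activation sets on GPU
    t = 0.0                 # compute stream clock
    off_free = 0.0          # offload stream becomes free at this time
    off_done = {}           # layer -> offload completion time
    stalls = []
    for i in range(1, model_layers + 1):
        # Layer i may start only after layer (i - budget) has been
        # offloaded; otherwise budget + 1 sets would be resident at once.
        blocker = i - budget
        if blocker in off_done and off_done[blocker] > t:
            stalls.append(i)
            t = off_done[blocker]
        start = t
        t += fwd_ms
        if i <= offloaded:
            # Offloading begins while the forward runs (first saved tensor).
            off_free = max(off_free, start) + off_ms
            off_done[i] = off_free
    return stalls

print(simulate_forward(5, 2))  # [] -> full overlap, as in Figure 3
print(simulate_forward(5, 3))  # [4, 5] -> Layers 4 and 5 stall, as in Figure 4
```

With slow offloads (off_ms > fwd_ms) the simulation reproduces the Figure 4 pattern: the budget of 2 resident sets forces Layers 4 and 5 to wait for earlier offloads.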

Manual Synchronization

For custom scheduling, set manual_synchronization=True. Optionally, pass a custom offload_stream for fine-grained synchronization. This mode returns a ManualOffloadSynchronizer with explicit control over transfers.

This mode is useful when training does not follow the standard “all forwards then all backwards” pattern — for example, in pipeline parallelism. Providing a custom offload_stream enables additional synchronization logic (e.g., waiting, recording events) tailored to the specific workload.

The ManualOffloadSynchronizer object provides the following methods:

  • start_offload_layer(layer_id) — queue async GPU→CPU copies on the offload stream. Before each copy, the offload stream waits for an event recorded when that tensor was saved for backward.

  • release_activation_forward_gpu_memory(layer_id) — make the current stream wait for this layer’s offload to complete, then release GPU memory.

  • start_reload_layer(layer_id) — queue async CPU→GPU copies on the offload stream. When tensors are accessed in backward, compute stream waits for each tensor’s reload to complete.

To skip offloading for a specific layer, simply do not call any of these methods for that layer.

The example below demonstrates:

  1. Forward pass: After each layer, call start_offload_layer(i) to begin async copy of layer i’s activations to CPU.

  2. Release GPU memory: Call release_activation_forward_gpu_memory(i) to free the GPU tensors. Each call waits internally for that layer’s offload to complete.

  3. Before backward: Call start_reload_layer(i) to begin async reload. The compute stream will automatically wait for each tensor to be reloaded before it’s accessed in backward.

import torch
from transformer_engine.pytorch import get_cpu_offload_context

# Setup
num_layers = 12
layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)]
x = torch.randn(16, 1024, 1024, device="cuda")

offload_stream = torch.cuda.Stream()
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
    enabled=True,
    model_layers=num_layers,
    manual_synchronization=True,
    offload_stream=offload_stream,
)

# Forward pass - manually trigger offload after each layer
for i in range(num_layers):
    with cpu_offload_context:
        x = layers[i](x)
    x = sync_function(x)
    manual_controller.start_offload_layer(i)

# Release GPU memory (each call waits for that layer's offload to complete)
for i in range(num_layers):
    manual_controller.release_activation_forward_gpu_memory(i)

# Start reloading before backward
for i in range(num_layers - 1, -1, -1):
    manual_controller.start_reload_layer(i)

# Backward pass
loss = x.sum()
loss.backward()

CPU Offloading and CUDA Graphs

CPU offloading works with CUDA graphs — async copies and stream synchronization are GPU operations that can be captured and replayed, even when accessing pinned CPU memory (via PCIe DMA, without CPU involvement).

Note

We recommend capturing the entire forward and backward pass in a single graph. Async copy operations (offload/reload) must complete within the same graph where they started. If the graph ends before copies finish, PyTorch will block waiting for them, defeating the purpose of graph capture.

import torch
from transformer_engine.pytorch import get_cpu_offload_context, make_graphed_callables

# Setup
num_layers = 12
offloaded_layers = 3
layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(num_layers)]

# Enable offloading for CUDA graphs
cpu_offload_context, sync_function = get_cpu_offload_context(
    enabled=True,
    model_layers=num_layers,
    num_layers=offloaded_layers,
)


# Wrap layers in a module that uses offloading
class OffloadedModel(torch.nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            with cpu_offload_context:
                x = layer(x)
            x = sync_function(x)
        return x


model = OffloadedModel(layers)
sample_input = (torch.randn(16, 1024, 1024, device="cuda"),)

# Create graphed callable (warmup is handled internally)
graphed_model = make_graphed_callables(model, sample_input)

# Use the graphed model
x = torch.randn(16, 1024, 1024, device="cuda")
out = graphed_model(x)
out.sum().backward()

Note

In PyTorch versions prior to 2.11, CPU offloading with CUDA graphs required passing retain_pinned_cpu_buffers=True to get_cpu_offload_context(). The root cause was that torch.empty with pinned CPU memory was not supported inside CUDA graph capture — buffers had to be pre-allocated and reused across iterations to avoid invalidating DMA addresses captured in the graph. This was fixed in pytorch#167507 (merged December 2025, shipping in PyTorch 2.11). On PyTorch 2.11+, retain_pinned_cpu_buffers is no longer needed.

Caveats

Warning

Heuristic activation detection:

CPU Offloading is implemented using PyTorch saved tensors hooks. PyTorch saves various tensors for backward — not just activations, but also weights and other data.
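A minimal CPU-only sketch of this mechanism, not Transformer Engine's implementation: the pack hook intercepts every tensor autograd saves for backward, and the unpack hook restores it during the backward pass. In real offloading, pack and unpack would issue asynchronous GPU-CPU copies instead of CPU-side clones:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

stash = {}  # stand-in for CPU-side storage

def pack(t):
    # Real offloading would launch an async GPU->CPU copy here.
    key = len(stash)
    stash[key] = t.detach().clone()
    return key

def unpack(key):
    # Real offloading would wait for the async CPU->GPU reload here.
    return stash[key]

x = torch.randn(4, 4, requires_grad=True)
with saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()  # autograd routes saved tensors through pack()

y.backward()
assert torch.allclose(x.grad, 2 * x)  # gradients are unaffected
print(len(stash), "saved tensor(s) intercepted")
```

Note that the hooks see everything autograd saves, not only activations, which is why the heuristic below is needed.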

Activation detection is heuristic. A CUDA tensor is offloaded if it:

  • has at least 256×1024 elements (~1 MB for float32),

  • is not a torch.nn.Parameter,

  • is not marked with mark_not_offload().

Additionally, non-contiguous tensors are skipped to avoid memory layout changes (see below). For TE layers, tensors that should not be offloaded are manually excluded. For non-TE layers, no such exclusions exist, so some tensors may remain pinned in GPU memory even after being copied to CPU (e.g., if the layer stores references in ctx), resulting in wasted bandwidth with no memory savings.
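The criteria above can be restated as a predicate. This is a pure-Python sketch of the documented rules, not Transformer Engine's actual implementation:

```python
# Pure-Python restatement of the documented heuristic (not TE source code).

MIN_ELEMENTS = 256 * 1024  # ~1 MB for float32

def should_offload(numel: int, is_parameter: bool,
                   marked_not_offload: bool, is_contiguous: bool) -> bool:
    """Mirror the documented offload criteria for a CUDA tensor."""
    return (numel >= MIN_ELEMENTS
            and not is_parameter
            and not marked_not_offload
            and is_contiguous)

# A large, contiguous activation qualifies:
print(should_offload(16 * 1024 * 1024, False, False, True))  # True
# A weight (torch.nn.Parameter) of the same size does not:
print(should_offload(16 * 1024 * 1024, True, False, True))   # False
# A small tensor does not:
print(should_offload(1024, False, False, True))              # False
```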

To exclude specific tensors from offloading, use mark_not_offload():

from transformer_engine.pytorch import mark_not_offload
mark_not_offload(tensor)

Warning

Memory layout changes:

Offloading/reloading can change tensor memory layout and relations:

  1. Views of the same storage may be restored as separate allocations.

  2. Adjacent tensors may not be adjacent after reload.

CUDA kernels that rely on specific memory layout may produce unexpected results. To mitigate (1), non-trivial views are excluded from offloading by default. TE attention kernels are an exception — they use internal handling that is tested and supported. Issue (2) is not mitigated — custom kernels that assume adjacent tensors share contiguous memory may still fail.

If you encounter layout-related issues, use mark_not_offload() to exclude problematic tensors from offloading.