# Source code for megatron.energon.wrappers.gc_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import gc
from typing import Any, Dict, Generic, Iterator, TypeVar

import torch
import torch.utils.data
import torch.utils.data.dataloader
from torch.distributed._shard.sharded_tensor import ShardedTensorBase
from torch.distributed.distributed_c10d import reduce_op

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

# Type of a single sample yielded by the wrapped dataset.
T_sample = TypeVar("T_sample")

# Strong references to all CUDA tensors found at worker-fork time, so the forked
# copies are never garbage collected (and thus never freed) in the worker process.
_frozen_cuda_tensors = set()
# Set to True once gc_init_worker() has completed in this process.
_frozen_cuda_tensors_initialized = False


# Default number of yielded samples between explicit gc.collect() calls in GcDataset.
GC_DEFAULT_EVERY_N_ITER = 10


class GcFreezeError(RuntimeError):
    """Raised when GcDataset iterates inside a forked worker process that never
    called :func:`gc_init_worker`."""


def gc_init_worker(worker_id: int):
    """This function should be called by any forked worker process that uses CUDA.
    It should be called as early as possible in the worker process, ideally in
    the worker_init_fn of the DataLoader.

    By keeping a reference to all CUDA tensors in the worker process, we can
    prevent the forked tensors from being garbage collected.

    Args:
        worker_id: Unused; present so this function matches the DataLoader
            ``worker_init_fn`` signature.
    """

    global _frozen_cuda_tensors_initialized

    for o in gc.get_objects():
        try:
            if o is reduce_op:
                # Skip this legacy torch object entirely; touching it triggers
                # its deprecation machinery.
                continue
            if isinstance(o, torch.Tensor):
                # The ShardedTensorBase check must come first: calling .is_cuda
                # (or any hasattr) on a ShardedTensor raises an error, so
                # o.is_cuda is only evaluated for regular tensors.
                if isinstance(o, ShardedTensorBase) or o.is_cuda:
                    # Keep a strong reference so the forked CUDA tensor is
                    # never garbage collected in this worker process.
                    _frozen_cuda_tensors.add(o)
            elif isinstance(o, torch.utils.data.dataloader._MultiProcessingDataLoaderIter):
                # Mark inherited DataLoader iterators as already shut down, so
                # their cleanup in this forked process does not tear down the
                # parent's worker infrastructure.
                o._shutdown = True
        except ReferenceError:
            # Can happen if the object is a weakref proxy, don't care
            pass

    _frozen_cuda_tensors_initialized = True


class GcDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """Applies a garbage collection step. This is needed, because python garbage collection
    does not work well with very large objects, such as tensors. This case happens, if there are
    a few hundred objects created and released every epoch (some of them being (large) tensors),
    where a lot of them are alive at the same time, but released later. In that case, those
    objects may end up in gc generation 2, where they may live until a lot of objects have been
    created, until automatic garbage collection of gen2 is actually triggered. To avoid this
    memory leak, `gc.collect()` is best to be called regularly. In addition, if `gc.freeze()`
    is used before the loop, it will remove the objects currently alive from garbage collection
    checks, thus making the gc faster.
    """

    # Number of yielded samples between explicit gc.collect() calls.
    every_n_iter: int
    # Whether to gc.freeze() before iterating and gc.unfreeze() afterwards.
    freeze: bool

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        *,
        worker_config: WorkerConfig,
        every_n_iter: int = GC_DEFAULT_EVERY_N_ITER,
        freeze: bool = True,
    ):
        """Construct a GcDataset, which applies garbage collection every `every_n_iter` iterations.

        Args:
            dataset: The input dataset to wrap
            worker_config: Worker configuration, forwarded to the wrapper base class.
            every_n_iter: How often to perform garbage collection
            freeze: If true, run `gc.freeze()` before the loop, and `gc.unfreeze()` after the
                loop. This will speed up garbage collection, but will keep all initially alive
                objects alive until the end of the loop (i.e. if the dataset state was restored,
                that state will be saved as well).
        """
        super().__init__(dataset, worker_config=worker_config)
        self.every_n_iter = every_n_iter
        self.freeze = freeze

    def reset_state_own(self) -> None:
        # This wrapper holds no own restorable state.
        return

    def __len__(self):
        return len(self.dataset)

    def __iter__(self) -> Iterator[T_sample]:
        in_worker = torch.utils.data.get_worker_info() is not None
        if in_worker and not _frozen_cuda_tensors_initialized:
            raise GcFreezeError(
                "You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process."
            )

        if self.freeze:
            # Collect once, then freeze: everything alive right now is moved to
            # the permanent generation and exempted from future gc passes,
            # which makes the periodic collections below cheaper.
            gc.collect()
            gc.freeze()
        try:
            # Renamed from `iter` to avoid shadowing the builtin.
            samples_since_gc = 0
            for sample in self.dataset:
                yield sample
                samples_since_gc += 1
                if samples_since_gc >= self.every_n_iter:
                    gc.collect()
                    samples_since_gc = 0
        finally:
            if self.freeze:
                gc.unfreeze()

    def config(self) -> Dict[str, Any]:
        # This is transparent, no config to be saved (it does not affect the dataset)
        return self.dataset.config()

    def __str__(self):
        return f"GcDataset(every_n_iter={self.every_n_iter}, dataset={self.dataset}, freeze={self.freeze})"