# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import gc
import multiprocessing
import queue
import threading
import time
import traceback
import weakref
from functools import partial
from typing import (
Any,
ClassVar,
Dict,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
import numpy as np
import torch
import torch.distributed
import torch.multiprocessing
from torch.utils.data import DataLoader, IterableDataset
from megatron.energon.dataclass_slots import dataclass_slots
from megatron.energon.errors import deprecated
from megatron.energon.flavors.base_dataset import (
FlexState,
SavableDataset,
State,
add_sample_restore_key,
)
from megatron.energon.rng import SystemRng, SystemRngState
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
from megatron.energon.wrappers.batch_dataset import BatchDataset
from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker
from megatron.energon.wrappers.log_sample_dataset import default_get_keys
T = TypeVar("T")
def _init_worker(seed_per_worker: List[int], worker_id: int):
"""Initializes the the worker process.
Sets the random seeds and prepare EPath for the forked worker process.
"""
gc_init_worker(worker_id)
worker_seed = seed_per_worker[worker_id]
SystemRng.seed(worker_seed)
class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Generic[T]):
"""Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is
not intended to be used directly."""
_state_restored: bool
_sample_index: int
_savable_fields = ("_sample_index",)
def __init__(self, dataset: SavableDataset[T], worker_config: WorkerConfig):
super().__init__(dataset, worker_config=worker_config)
self.reset_state_own()
def reset_state_own(self) -> None:
self._sample_index = 0
self._state_restored = False
def __len__(self):
return len(self.dataset)
def __iter__(self):
self._state_restored = True
worker_id = self.worker_config.rank_worker_id()
global_worker_id = self.worker_config.global_worker_id()
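# Yields (worker_id, sample_index, sample) tuples; the outer while loop restarts the
# inner iterator if the state was restored while iterating (see the _state_restored flag).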
while self._state_restored:
self._state_restored = False
self.worker_config.worker_activate(self._sample_index)
worker_active = True
try:
for src_data in self.dataset:
self.worker_config.worker_deactivate()
worker_active = False
sample_index = self._sample_index
src_data = add_sample_restore_key(
src_data, global_worker_id, sample_index, src=self
)
self._sample_index += 1
yield worker_id, sample_index, src_data
if self._state_restored:
# Restart iterator after restore
break
self.worker_config.worker_activate(self._sample_index)
worker_active = True
finally:
if worker_active:
self.worker_config.worker_deactivate()
def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T:
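# The restore key has the form (type name, global_worker_id, sample_idx, *inner key);
# strip this wrapper's part and delegate the remainder to the wrapped dataset.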
id, global_worker_id, sample_idx = index[:3]
assert id == type(self).__name__
index = index[3:]
self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id)
try:
return add_sample_restore_key(
self.dataset.restore_sample(index),
global_worker_id,
sample_idx,
src=self,
)
finally:
self.worker_config.worker_deactivate()
def config(self) -> Dict[str, Any]:
return self.dataset.config()
def __str__(self):
return f"SimpleSavableDatasetWrapper(dataset={self.dataset})"
@dataclass_slots
class SavableDatasetState(State):
"""State of the dataset wrapper. It stores the global random states and the index of the next
sample to be returned from the dataset. This class is not intended to be used directly, but by
:class:`megatron.energon.SavableDatasetWrapper`."""
#: The state of all the system random number generators
rng: SystemRngState
#: The state of the savable dataset
dataset_state: FlexState
#: Index of the next sample to be returned from the dataset
sample_index: int
def __repr__(self):
return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})"
@dataclass_slots
class SavableCheckpoint:
"""Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. An instance is created
regularly to be able to save the state of the dataset wrapper before the currently emitted
sample.
"""
#: The state of the wrapper
state: Optional[SavableDatasetState]
#: The time at which the checkpoint was created
checkpoint_time: float
#: Index of the next sample to be returned from the dataset after restoring the checkpoint
sample_index: int
@dataclass_slots
class SavableDatasetCheckpoint(State):
"""Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. The checkpoint state
represents a state before that checkpoint, with an offset (i.e. samples to be skipped)."""
#: The state of the wrapper at the sample index when the checkpoint was created.
state: Optional[SavableDatasetState]
#: Offset of the checkpoint to the actual sample index to be restored.
offset: int
class SavableDatasetWrapper(IterableDataset[Tuple[int, int, T]], Generic[T]):
"""Internal class for wrapping a savable dataset for a worker process. Provides communication
with the :class:`megatron.energon.SavableDataLoader`. This class is not intended to be used directly.
See :class:`megatron.energon.SavableDataLoader` for more information."""
#: The wrapped dataset
dataset: SavableDataset[T]
#: The configuration of the worker process
worker_config: WorkerConfig
#: The time interval in seconds to wait at minimum between two checkpoints
checkpoint_every_sec: float
#: The minimum number of samples to be emitted between two checkpoints. Should be `number of
# workers * 2`.
checkpoint_every_min_n_samples: int
#: The number of checkpoints to keep in memory, before discarding. Should be 2.
n_checkpoints: int
#: The queue of the worker process to receive commands from the `SavableDataLoader`.
_cmd_queues: List[torch.multiprocessing.Queue]
#: The queue of the worker process to send results to the `SavableDataLoader`.
_result_queues: List[torch.multiprocessing.Queue]
_sample_index: int = 0
_worker_offset: int = 0
_last_checkpoints: List[SavableCheckpoint]
_workers_restore_from: List[Optional[SavableDatasetState]] = list()
_workers_skip_samples: List[int]
_running: bool = False
_command_lock: Optional[threading.RLock] = None
_cmd_thread: Optional[threading.Thread] = None
def __init__(
self,
dataset: SavableDataset[T],
worker_config: WorkerConfig,
checkpoint_every_sec: float,
checkpoint_every_min_n_samples: int,
n_checkpoints: int = 2,
*,
cmd_queues: List[torch.multiprocessing.Queue],
result_queues: List[torch.multiprocessing.Queue],
):
"""
Create the savable dataset wrapper for multiprocessing data loading.
Args:
dataset: The dataset to wrap
worker_config: The worker config as used by all datasets
checkpoint_every_sec: The time interval in seconds to wait at minimum between two
checkpoints.
checkpoint_every_min_n_samples: The minimum number of samples to be emitted between
two checkpoints. Should be `number of workers * 2`.
n_checkpoints: Number of checkpoints to keep.
cmd_queues: The command queues for communicating with the worker processes.
result_queues: The result queues for communicating with the worker processes.
"""
num_workers = max(worker_config.num_workers, 1)
self.dataset = dataset
self.worker_config = worker_config
self.checkpoint_every_sec = checkpoint_every_sec
self.checkpoint_every_min_n_samples = checkpoint_every_min_n_samples
self.n_checkpoints = n_checkpoints
self._last_checkpoints = [
SavableCheckpoint(state=None, checkpoint_time=time.perf_counter(), sample_index=-1)
]
self._workers_skip_samples = [0] * num_workers
self._cmd_queues = cmd_queues
self._result_queues = result_queues
@staticmethod
def _command_thread(self: "SavableDatasetWrapper"):
"""The internal thread, which processes the command and result queues. This thread is
static, because `self` is actually passed as weakref proxy, to avoid keeping the dataset
alive via the thread.
"""
# print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread starting")
assert self._command_lock is not None
try:
while self._running:
try:
cmd_args = self._cmd_queues[self._worker_id].get(timeout=1)
except queue.Empty:
continue
# print(f"recv cmd {cmd_args}")
with self._command_lock:
cmd = cmd_args[0]
if cmd is None:
break
try:
fn = getattr(self, cmd)
self._result_queues[self._worker_id].put(
{self._worker_id: fn(*cmd_args[1:])}
)
# print(f"result sent")
except Exception as e:
traceback.print_exc()
self._result_queues[self._worker_id].put({self._worker_id: e})
# print(f"exc sent")
except BaseException:
traceback.print_exc()
raise
finally:
pass
# print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing")
def __len__(self):
return len(self.dataset)
def __del__(self):
if self._cmd_thread is not None:
# print(f"{id(self)}:{multiprocessing.current_process().ident} Closing cmd thread")
self._running = False
self._cmd_thread.join()
self._command_lock = None
self._cmd_thread = None
# print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed")
def __iter__(self):
# First: Set the worker offset globally for the current worker
WorkerConfig.worker_id_offset = self._worker_offset
self._worker_id = self.worker_config.rank_worker_id()
global_worker_id = self.worker_config.global_worker_id()
if self._cmd_thread is None:
self._running = True
self._command_lock = threading.RLock()
weakref_self = weakref.proxy(self)
self._cmd_thread = threading.Thread(
target=SavableDatasetWrapper._command_thread,
name="command_thread",
args=(weakref_self,),
daemon=True,
)
self._cmd_thread.start()
# atexit.register(lambda: weakref_self.__del__())
try:
assert self._command_lock is not None
with self._command_lock:
if self._workers_restore_from:
my_state = self._workers_restore_from[self._worker_id]
assert my_state is not None
my_ds_state = my_state.dataset_state
if my_ds_state is None:
self.dataset.reset_state_deep()
else:
self.dataset.restore_state(my_ds_state)
self._restore_state(my_state)
self._workers_restore_from = []
else:
# Store the initial state of the worker if we stop before the first sample
self._store_checkpoint()
# If skipping, also restart the iterator to reach the start of the restored
# checkpoint
last_was_skip = True
while last_was_skip:
dataset_has_samples = False
self.worker_config.worker_activate(self._sample_index)
worker_active = True
try:
for src_data in self.dataset:
self.worker_config.worker_deactivate()
worker_active = False
dataset_has_samples = True
if self._workers_skip_samples[self._worker_id] > 0:
# Skip ahead to reach the start of the restored checkpoint
# print(f"Skip [{self._worker_id}:{self._sample_index}] {src_data}")
self._workers_skip_samples[self._worker_id] -= 1
self._sample_index += 1
last_was_skip = True
self.worker_config.worker_activate(self._sample_index)
worker_active = True
continue
last_was_skip = False
sample_index = self._sample_index
add_sample_restore_key(
src_data, global_worker_id, sample_index, src=self
)
self._sample_index += 1
self._store_checkpoint()
try:
self._command_lock.release()
# print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released")
# Commands may be executed only when data was yielded, not during
# iteration fetching.
# print(f"Yield next data [{self._worker_id}:{sample_index}] {src_data}")
yield self._worker_id, sample_index, src_data
finally:
# print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring")
self._command_lock.acquire()
# print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired")
self.worker_config.worker_activate(self._sample_index)
worker_active = True
finally:
if worker_active:
self.worker_config.worker_deactivate()
# If the dataset is empty, don't try again and again
if not dataset_has_samples:
break
finally:
# print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing")
# Always store a final checkpoint (it's likely to be saved)
self._store_checkpoint(force=True)
def _store_checkpoint(self, force: bool = False) -> None:
"""
Internally create a checkpoint for the current state. This is required to store states
from the past, which have already been yielded here, but not yet been retrieved from the
intermediate queues.
Args:
force: If true, ignore time or frequency condition.
"""
if (
force
or (
self._last_checkpoints[-1].checkpoint_time + self.checkpoint_every_sec
< time.perf_counter()
and self._last_checkpoints[-1].sample_index + self.checkpoint_every_min_n_samples
<= self._sample_index
)
or self._sample_index <= 1
):
# print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}")
self._last_checkpoints.append(
SavableCheckpoint(
state=self._save_state(),
checkpoint_time=time.perf_counter(),
sample_index=self._sample_index,
)
)
if len(self._last_checkpoints) > self.n_checkpoints:
self._last_checkpoints.pop(0)
def _save_state(self) -> SavableDatasetState:
"""Saves the internal state"""
return SavableDatasetState(
rng=SystemRng.save_state(),
dataset_state=self.dataset.save_state(),
sample_index=self._sample_index,
)
def _restore_state(self, state: SavableDatasetState) -> None:
"""Restores the internal worker state"""
assert torch.utils.data.get_worker_info() is not None, "Can only restore in worker process"
if state.rng is None:
SystemRng.seed(torch.initial_seed() & 0xFFFFFFFF)
else:
SystemRng.restore_state(state.rng)
self._sample_index = state.sample_index
self._last_checkpoints = [
SavableCheckpoint(
state=self._save_state(),
checkpoint_time=time.perf_counter(),
sample_index=self._sample_index,
)
]
def get_checkpoint(self, last_sample_indexes: List[int]) -> SavableDatasetCheckpoint:
"""
Get a checkpoint given the last emitted sample indexes for all workers.
Args:
last_sample_indexes: The last emitted sample indexes for all workers.
Returns:
The found checkpoint including the offset to the next sample index
"""
sample_index = last_sample_indexes[self._worker_id] + 1
for checkpoint in reversed(self._last_checkpoints):
if checkpoint.sample_index <= sample_index:
# print(f"Found cp for {sample_index} at {checkpoint.sample_index}")
return SavableDatasetCheckpoint(
state=checkpoint.state,
offset=sample_index - checkpoint.sample_index,
)
raise ValueError("No checkpoint found")
def restore_checkpoint(
self,
worker_states: Optional[List[SavableDatasetCheckpoint]],
worker_offset: int,
) -> None:
"""
Restores the merged checkpoint from all worker processes.
Args:
worker_states: The state to restore for each worker
worker_offset: The offset of the last worker which has emitted a sample. This will be
set in all worker processes to ensure the right worker starts as first.
"""
assert torch.utils.data.get_worker_info() is None, "Cannot restore in worker process"
num_workers = max(self.worker_config.num_workers, 1)
if worker_states is None:
self._workers_restore_from = []
assert worker_offset == 0
self._worker_offset = 0
self._workers_skip_samples = [0] * num_workers
else:
assert isinstance(worker_states, list)
assert isinstance(worker_states[0], SavableDatasetCheckpoint)
self._worker_offset = worker_offset
# Tear the state_list apart (which has len=num_workers)
# and store the states in the internal arrays
self._workers_restore_from = [state.state for state in worker_states]
self._workers_skip_samples = [state.offset for state in worker_states]
def can_restore_sample(self) -> bool:
return self.dataset.can_restore_sample()
def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T:
id, global_worker_id, sample_idx = index[:3]
assert id == type(self).__name__
index = index[3:]
self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id)
try:
return add_sample_restore_key(
self.dataset.restore_sample(index),
global_worker_id,
sample_idx,
src=self,
)
finally:
self.worker_config.worker_deactivate()
def config(self) -> Dict[str, Any]:
return self.dataset.config()
def __str__(self):
return f"SavableDatasetWrapper(dataset={self.dataset})"
@dataclass_slots
class SavableDataLoaderState(State):
"""Saved state of the :class:`megatron.energon.SavableDataLoader`. Contains the state for all worker
processed of a single rank."""
#: The internal state of the dataset (for each worker process)
worker_states: List[Union[SavableDatasetCheckpoint, FlexState]]
#: Which worker will be the next to emit a sample. Used to restore the proper order
next_worker_id: int
#: The micro batch size that was used, if available.
#: On restore, this is used to compare the new and old micro batch size.
micro_batch_size: Optional[int]
class SavableDataLoader(DataLoader[T], Generic[T]):
"""DataLoader that supports saving and restoring the state of the dataset.
When restoring, the dataloader and dataset must be instantiated with the exactly same
parameters.
How this works (for no worker processes)
----------------------------------------
1. The state of the dataset is saved using :meth:`megatron.energon.SavableDataset.save_state`
2. (for compatibility) The state of the dataset is merged into a single state using
:meth:`megatron.energon.SavableDataset.merge_states`.
3. The state can be restored using :meth:`megatron.energon.SavableDataset.restore_state` given the
previously saved (and merged) state.
How this works (for worker processes)
-------------------------------------
- The first issue is that worker processes use internal queues to pass loaded samples to
the main process (which also performs collating). This means that the whole state of the
dataset is not directly accessible from the main process.
- To solve this issue, the dataset regularly saves a checkpoint of its state to be able to
resume from that state (and skip the samples that have already been yielded).
- To have a consistent state, the sample index from the latest yielded samples is saved for all
worker instances. Thus, the main process knows exactly which sample indexes should come next
from which worker.
- Internally, pytorch iterates through the workers in order to retrieve the next worker's
samples. Unfortunately, that next worker index cannot be restored in pytorch's dataloader,
thus the workers are shifted internally by that offset
(see :attr:`megatron.energon.WorkerConfig.worker_id_offset`).
1. The dataset is wrapped in a :class:`megatron.energon.SavableDatasetWrapper`. This allows the main
process to communicate with the worker and send commands to the workers and retrieve the
results.
2. The state of the dataset is saved using
:meth:`megatron.energon.SavableDatasetWrapper.get_checkpoint`. This gives the last checkpoint
from the requested sample index and stores the offset (i.e. number of samples to skip) from
that checkpoint.
3. The state is merged using :meth:`megatron.energon.SavableDatasetWrapper.merge_checkpoints`. This
merges the states of all workers and returns a single state that can be used to restore the
state of the dataset.
4. The state can be restored using :meth:`megatron.energon.SavableDatasetWrapper.restore_state`
before a worker is started, such that all workers initially receive the same state array.
Each worker first sets the worker index offset, then uses its (shifted) own index to get its
required state from the merged state array.
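
Minimal usage sketch (hedged; ``my_dataset`` is a placeholder for any
:class:`megatron.energon.SavableDataset` constructed with a matching worker config)::

    loader = SavableDataLoader(my_dataset)
    it = iter(loader)
    for _ in range(100):
        batch = next(it)
    # Capture the state after consuming 100 samples
    state = loader.save_state_rank()

    # Later: construct dataset and loader with exactly the same parameters, restore the
    # state before iterating, then continue where the previous run left off.
    loader = SavableDataLoader(my_dataset)
    loader.restore_state_rank(state)
    it = iter(loader)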
"""
#: The worker config
worker_config: WorkerConfig
#: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper`
dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapper[T]]
#: The global ID counter
_next_id: ClassVar[int] = 0
#: Class instance id
id: int = 0
#: The queues used to send commands to the workers
cmd_queues: List[torch.multiprocessing.Queue]
#: The queues used to receive results from the workers
result_queues: List[torch.multiprocessing.Queue]
#: Instance of the current data iterator. There shall be only one active iterator, such that the
# dataset is not iterated multiple times in parallel. The state advances with this iterator.
_persistent_iterator: Optional[Iterator[T]] = None
#: The last emitted sample index for each worker process. -1 if not started yet.
_worker_sample_counters: List[int]
#: Id of the next worker to retrieve data from
_next_worker_id: int = 0
#: Global index of the last yielded sample
_global_sample_idx: int = 0
#: Current iterator index of the last yielded sample
_sample_idx: int = 0
def __init__(
self,
dataset: SavableDataset[T],
*,
checkpoint_every_sec: float = 60,
checkpoint_every_min_n_samples: Optional[int] = None,
n_checkpoints: Optional[int] = None,
gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
gc_freeze_at_start: bool = True,
prefetch_factor: int = 2,
):
"""
Create the dataloader supporting saving and restoring the state.
Args:
dataset: The dataset to load.
checkpoint_every_sec: This is the time in seconds after which a checkpoint is saved.
It may take the same duration to restore a checkpoint, but introduces additional
overhead during reading data from the dataset, so this should be chosen accordingly.
Only applies if using workers.
checkpoint_every_min_n_samples: Overwrites the minimum number of samples between
checkpoints. Defaults to `number of workers * 2`. Only applies if using workers.
n_checkpoints: The number of checkpoints to keep in memory. Only applies if using
workers. If None, computes a suitable value.
gc_collect_every_n_steps: The number of steps after which the garbage collector is
called. As we're usually handling large (but few) tensors here, and the python
garbage collection is already full of objects just by importing, this can improve
the memory footprint quite a lot, and may even be necessary to avoid memory
overflow.
gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker
processes. This improves the garbage collection performance by a lot.
In rare cases, this may cause issues and can be disabled. Keep enabled if you
experience no issues.
prefetch_factor: Passed through to the inherited torch :class:`DataLoader`. Also used to
compute the default number of internal checkpoints to keep.
"""
self.worker_config = dataset.worker_config
self.id = self.next_id()
if gc_collect_every_n_steps > 0:
dataset = GcDataset(
dataset,
worker_config=self.worker_config,
every_n_iter=gc_collect_every_n_steps,
freeze=gc_freeze_at_start,
)
self.cmd_queues = [multiprocessing.Queue() for _ in range(self.worker_config.num_workers)]
self.result_queues = [
multiprocessing.Queue() for _ in range(self.worker_config.num_workers)
]
num_procs = max(self.worker_config.num_workers, 1)
if n_checkpoints is None:
n_checkpoints = prefetch_factor * num_procs + 1
if self.worker_config.num_workers > 0:
if checkpoint_every_min_n_samples is None:
checkpoint_every_min_n_samples = self.worker_config.num_workers * 2
dataset = SavableDatasetWrapper(
dataset,
self.worker_config,
checkpoint_every_sec=checkpoint_every_sec,
checkpoint_every_min_n_samples=checkpoint_every_min_n_samples,
n_checkpoints=n_checkpoints,
cmd_queues=self.cmd_queues,
result_queues=self.result_queues,
)
else:
dataset = SimpleSavableDatasetWrapper(dataset, self.worker_config)
self._worker_sample_counters = [-1] * num_procs
kwargs = {}
if self.worker_config.num_workers > 0:
kwargs["persistent_workers"] = True
kwargs["prefetch_factor"] = prefetch_factor
# Assert that prefetch_factor works well with num_checkpoints.
# This ensures that the oldest checkpoint is old enough to cover
# all the buffered samples in the torch dataloader.
assert prefetch_factor * num_procs + 1 <= n_checkpoints, (
"When increasing prefetch_factor, also increase n_checkpoints, so that "
"the number of checkpoints is at least as large as num_workers * prefetch_factor + 1"
)
# Compute seeds for each worker, based on current rank
seed_per_worker = [
self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers)
]
super().__init__(
dataset,
batch_size=None,
shuffle=False,
num_workers=self.worker_config.num_workers,
pin_memory=True,
worker_init_fn=partial(_init_worker, seed_per_worker),
**kwargs,
)
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "SavableLoader.__init__",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"config": dataset.config(),
}
)
@staticmethod
def next_id() -> int:
next_id = SavableDataLoader._next_id
SavableDataLoader._next_id += 1
return next_id
def __iter__(self):
outerself = self
class InnerIterator:
"""Internal class which keeps the iterator alive across multiple `iter()` calls.
If the inner iterator is exhausted, this iterator is exhausted as well and a new instance is needed.
Also saves the last sample index and the next worker id.
"""
finished: bool = False
iter_idx: int = 0
id: int
def __init__(self, iterator):
self._iterator = iterator
self.id = outerself.next_id()
if outerself.worker_config.should_log(level=1):
outerself.worker_config.worker_log(
{
"t": "SavableDataLoader.iter",
"r": outerself.worker_config.rank,
"w": None,
"id": outerself.id,
"iter_id": self.id,
}
)
# self._debugf = open(
# f"worker_samples_rank{outerself.worker_config.rank:02}_t{int(time.time())}.log", "w"
# )
def __iter__(self):
return self
def __next__(self):
try:
worker_id, sample_idx, sample = next(self._iterator)
outerself._worker_sample_counters[worker_id] = sample_idx
# If the next sample will be from the first worker, we can safely resume
outerself._next_worker_id = (worker_id + 1) % max(outerself.num_workers, 1)
# self._debugf.write(
# f"[w={worker_id}, s={sample_idx}] {self._sample_str(sample)}\n"
# )
# self._debugf.flush()
if outerself.worker_config.should_log(level=1):
keys = default_get_keys(sample)
outerself.worker_config.worker_log(
{
**{
"t": "SavableDataLoader.yield",
"r": outerself.worker_config.rank,
"w": None,
"id": outerself.id,
"iter_id": self.id,
"worker_id": worker_id,
"worker_idx": sample_idx,
"idx": outerself._sample_idx,
"iter_idx": self.iter_idx,
"global_idx": outerself._global_sample_idx,
},
**({} if keys is None else {"keys": keys}),
}
)
outerself._sample_idx += 1
outerself._global_sample_idx += 1
self.iter_idx += 1
return sample
except StopIteration:
self.finished = True
outerself._next_worker_id = 0
if outerself.worker_config.should_log(level=1):
outerself.worker_config.worker_log(
{
"t": "SavableDataLoader.StopIteration",
"r": outerself.worker_config.rank,
"w": None,
"id": outerself.id,
"iter_id": self.id,
}
)
raise
if self.num_workers > 0:
# Always keep same iterator alive, as long as it yields data
if self._persistent_iterator is None or self._persistent_iterator.finished:
self._persistent_iterator = InnerIterator(super().__iter__())
self._sample_idx = 0
# print("New Iterator", self._persistent_iterator)
return self._persistent_iterator
else:
return InnerIterator(super().__iter__())
def _worker_command(self, *cmd_args) -> List[Any]:
"""Executes a command in all workers and returns the results."""
# print(f"cmd: {cmd_args}")
for cmd_queue in self.cmd_queues:
cmd_queue.put(cmd_args)
# print(f"waiting for res")
assert len(self.result_queues) == self.worker_config.num_workers
res = {k: v for results_queue in self.result_queues for k, v in results_queue.get().items()}
res = [res[i] for i in range(len(res))]
# print(f"res: {res}")
for r in res:
if isinstance(r, Exception):
raise r
return res
def _get_batch_size(self) -> Optional[int]:
"""Try to infer micro batch size from the dataset"""
if isinstance(self.dataset, SavableDatasetWrapper):
dataset = self.dataset.dataset
else:
dataset = self.dataset
if (
isinstance(dataset, BaseWrapperDataset)
and (bds := dataset._find_wrapped_dataset(BatchDataset)) is not None
):
assert isinstance(bds, BatchDataset)
return bds.batch_size
else:
return None
def save_state_rank(self) -> Optional[SavableDataLoaderState]:
"""
Saves the state of the dataset for the current rank. Allows for restoring the state later
using `restore_state_rank`, given the result of this method.
Returns:
The state of the dataset.
"""
# Fetch current rank's worker's state
if self.num_workers == 0:
# No workers configured
assert isinstance(self.dataset, SimpleSavableDatasetWrapper)
worker_states = [self.dataset.save_state()]
assert self._next_worker_id == 0
elif self._persistent_iterator is None:
# Workers configured, but not started yet -> Initial state
return None
else:
# Fetch from worker processes
worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters)
# Merge the states
merged_state = SavableDataLoaderState(
worker_states=worker_states,
next_worker_id=self._next_worker_id,
micro_batch_size=self._get_batch_size(),
)
# Not distributed -> return the merged state
return merged_state
def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None:
"""
Restores the saved state for the current rank.
Args:
state: The state to restore, as saved by `save_state_rank`.
"""
assert self._persistent_iterator is None, "Cannot restore state while workers are running"
if state is None:
# Assume initial state
return
assert isinstance(state, SavableDataLoaderState)
old_micro_batch_size = state.micro_batch_size
micro_batch_size = self._get_batch_size()
if isinstance(self.dataset, SavableDataset):
assert micro_batch_size == old_micro_batch_size, (
"Changing micro batch size is not allowed without workers"
)
assert len(state.worker_states) == 1
assert isinstance(state.worker_states[0], FlexState)
self.dataset.restore_state(state.worker_states[0])
else:
assert isinstance(self.dataset, SavableDatasetWrapper)
assert all(isinstance(s, SavableDatasetCheckpoint) for s in state.worker_states)
# Check batch sizes (before and after)
if micro_batch_size != old_micro_batch_size:
assert micro_batch_size is not None and old_micro_batch_size is not None, (
"Cannot resume with different batching mode "
"(batching to non-batching or vice versa)"
)
if micro_batch_size > old_micro_batch_size:
raise ValueError(
"Resuming with larger micro batch size is not allowed: "
f"{micro_batch_size} > {state.micro_batch_size}"
)
elif (
micro_batch_size < old_micro_batch_size
and old_micro_batch_size % micro_batch_size != 0
):
raise ValueError(
"Resuming with smaller micro batch size only allowed if the old "
f"micro batch size is a multiple of the new one: {micro_batch_size} < {state.micro_batch_size}"
)
batch_size_ratio = old_micro_batch_size // micro_batch_size
for worker_state in state.worker_states:
assert isinstance(worker_state, SavableDatasetCheckpoint)
# When resuming with a smaller micro batch size, the offset must be scaled
# up to the new micro batch size to skip the same number of samples as before.
worker_state.offset *= batch_size_ratio
self.dataset.restore_checkpoint(state.worker_states, worker_offset=state.next_worker_id)
@deprecated(
"`save_state` is deprecated and was renamed to `save_state_global` and will be removed "
"in a future update. If you actually do not want to gather the states to a rank, use "
"`save_state_rank` instead."
)
def save_state(self, dst_rank: int) -> Optional[Sequence[Optional[SavableDataLoaderState]]]:
"""Deprecated. Use `save_state_global` (or `save_state_rank`) instead."""
return self.save_state_global(dst_rank)
def save_state_global(
self, global_dst_rank: int
) -> Optional[Sequence[Optional[SavableDataLoaderState]]]:
"""
Saves the state of the dataset globally, collecting the state from all ranks using torch
distributed. Allows for restoring the state later using `restore_state_global`, given the
result of this method.
Typical scenario: Save the state to disk only on the destination rank; the other ranks do
not save the state. Later, restore the state using `restore_state_global`, either by
loading it only on one rank (and passing `src_rank`), or by loading it on all ranks
separately.
Note: If you want to save/restore the state per rank separately, use `save_state_rank` and
the corresponding `restore_state_rank`. Also, these do not rely on torch distributed.
Args:
global_dst_rank: The state will be gathered to this rank. The rank refers to the
global rank, not the rank within the data parallel group.
Returns:
The state of the dataset (or `None`, if not on `dst_rank`).
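
Example (hedged sketch; assumes torch distributed is initialized and global rank 0 is
the destination)::

    state = loader.save_state_global(global_dst_rank=0)
    if torch.distributed.get_rank() == 0:
        torch.save(state, "loader_state.pt")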
"""
# Fetch current rank's worker's state
merged_state = self.save_state_rank()
# Gather the merged states
if self.worker_config.world_size > 1:
output: Optional[Sequence[Optional[SavableDataLoaderState]]]
if self.worker_config.global_rank() == global_dst_rank:
output = [None] * self.worker_config.world_size
else:
# Check if the global_dst_rank is in the same group at all
if self.worker_config.data_parallel_group is not None:
try:
_ = torch.distributed.get_group_rank(
self.worker_config.data_parallel_group, global_dst_rank
)
except RuntimeError:
raise ValueError(
f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config"
)
output = None
torch.distributed.gather_object(
merged_state,
output,
global_dst_rank,
group=self.worker_config.data_parallel_group,
)
return output
else:
# Not distributed -> return the merged state
return [merged_state]
@deprecated(
"`restore_state` was renamed to `restore_state_global` and will be removed in a future update."
)
def restore_state(
self,
state: Optional[Sequence[Optional[SavableDataLoaderState]]],
) -> None:
"""Deprecated. Use `restore_state_global` (or `restore_state_rank`) instead."""
return self.restore_state_global(state)
def restore_state_global(
self,
state: Optional[Sequence[Optional[SavableDataLoaderState]]],
*,
src_rank: Optional[int] = None,
) -> None:
"""
Restores the saved state from `save_state_global` (in torch distributed setup).
The global state needs to be loaded on every rank that has a data loader instance.
Optionally, one can specify a src_rank and only provide the state once.
In case of multiple data parallel groups, you must provide the state once
in each data parallel group. In this case the `src_rank` is the rank within the
data parallel group.
Args:
state: The state to restore, as saved by `save_state_global`.
src_rank: The rank from which the state is broadcasted (within the data parallel group, if using DP groups).
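
Example (hedged sketch; state saved on global rank 0 as in `save_state_global`, and no
separate data parallel group, so the data parallel rank equals the global rank)::

    # (with newer torch versions, ``weights_only=False`` may be needed for torch.load)
    state = torch.load("loader_state.pt") if torch.distributed.get_rank() == 0 else None
    loader.restore_state_global(state, src_rank=0)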
"""
assert self._persistent_iterator is None, "Cannot restore state while workers are running"
# Only restore multi-rank if state is actually a list and we are in a distributed setup.
# Otherwise treat as single rank state.
if src_rank is None or self.worker_config.world_size == 1:
assert isinstance(state, list), "State must be a list in distributed setup"
assert len(state) == self.worker_config.world_size, (
"State must be a list of size world_size"
)
# All ranks have the state
# Select the state of the current rank
rank_state = state[self.worker_config.rank]
else:
if self.worker_config.data_parallel_group is not None:
# Only the src_rank has the state within this dp group
try:
global_src_rank = torch.distributed.get_global_rank(
self.worker_config.data_parallel_group, src_rank
)
except RuntimeError:
raise ValueError(
f"src_rank {src_rank} is not in the group of the current rank's worker config"
)
else:
# If no DP group is given, we assume the global rank is
# the same as the data parallel rank
global_src_rank = src_rank
if self.worker_config.rank != src_rank:
# This rank does not provide the state; it receives it via the scatter below
assert state is None
# Must still be a list of Nones
state = [None] * self.worker_config.world_size
else:
assert isinstance(state, list), "State must be a list in distributed setup"
assert len(state) == self.worker_config.world_size, (
"State must be a list of size world_size"
)
local_object = [None]
torch.distributed.scatter_object_list(
local_object,
state,
src=global_src_rank,
group=self.worker_config.data_parallel_group,
)
rank_state = local_object[0]
self.restore_state_rank(rank_state)
def can_restore_sample(self) -> bool:
return self.dataset.can_restore_sample()
def restore_sample(self, sample_key: Tuple[Union[str, int, tuple], ...]) -> T:
"""Restores a sample from a key. This is useful to debug the dataset."""
return self.dataset.restore_sample(sample_key)
def config(self):
"""Get the configuration, which defines the dataset. Useful in conjunction with `save_state`
and `restore_state` to match the configuration as well."""
return {
"type": type(self).__qualname__,
"num_workers": self.num_workers,
"persistent_workers": self.persistent_workers,
"pin_memory": self.pin_memory,
"prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor,
"dataset": self.dataset.config(),
}
class BasicDataLoader(DataLoader[T], Generic[T]):
"""DataLoader that supports debugging the dataset without saving capability (e.g. for val/eval)."""
#: The worker config
worker_config: WorkerConfig
#: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper`
dataset: Union[SavableDatasetWrapper[T], SavableDataset[T]]
id: int
_sample_idx: int = 0
def __init__(
self,
dataset: SavableDataset[T],
gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
gc_freeze_at_start: bool = True,
prefetch_factor: int = 2,
):
"""
Create the dataloader supporting saving and restoring the state.
Args:
dataset: The dataset to load.
gc_collect_every_n_steps: The number of steps after which the garbage collector is
called. As we're usually handling large (but few) tensors here, and the python
garbage collection is already full of objects just by importing, this can improve
the memory footprint quite a lot, and may even be necessary to avoid memory
overflow.
gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker
processes. This improves the garbage collection performance by a lot.
In rare cases, this may cause issues and can be disabled. Keep enabled if you
experience no issues.
"""
self.worker_config = dataset.worker_config
self.id = SavableDataLoader.next_id()
if gc_collect_every_n_steps > 0:
dataset = GcDataset(
dataset,
worker_config=self.worker_config,
every_n_iter=gc_collect_every_n_steps,
freeze=gc_freeze_at_start,
)
dataset = SimpleSavableDatasetWrapper(dataset, worker_config=self.worker_config)
self._worker_sample_counters = [0] * max(self.worker_config.num_workers, 1)
kwargs = {}
if self.worker_config.num_workers > 0:
# These must not be specified for num_workers == 0
kwargs["persistent_workers"] = True
kwargs["prefetch_factor"] = prefetch_factor
seed_per_worker = [
self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers)
]
gc.collect() # This ensures that we don't include any old worker refs in the newly forked worker processes
super().__init__(
dataset,
batch_size=None,
shuffle=False,
num_workers=self.worker_config.num_workers,
pin_memory=True,
worker_init_fn=partial(_init_worker, seed_per_worker),
**kwargs,
)
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "BasicDataLoader.__init__",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"config": self.config(),
}
)
def __iter__(self):
outerself = self
class InnerIterator:
"""Internal class which keeps the iterator alive across multiple `iter()` calls.
If the inner iterator is exhausted, this iterator is exhausted as well and a new instance is needed.
Also saves the last sample index and the next worker id.
"""
iter_idx: int = 0
id: int
def __init__(self, iterator):
self._iterator = iterator
self.id = SavableDataLoader.next_id()
if outerself.worker_config.should_log(level=1):
outerself.worker_config.worker_log(
{
"t": "BasicDataLoader.iter",
"r": outerself.worker_config.rank,
"w": None,
"id": outerself.id,
"iter_id": self.id,
}
)
def __iter__(self):
return self
def __next__(self):
try:
worker_id, sample_idx, sample = next(self._iterator)
# If the next sample will be from the first worker, we can safely resume
self.next_worker_id = (worker_id + 1) % max(outerself.num_workers, 1)
if outerself.worker_config.should_log(level=1):
keys = default_get_keys(sample)
outerself.worker_config.worker_log(
{
**{
"t": "BasicDataLoader.yield",
"r": outerself.worker_config.rank,
"w": None,
"id": outerself.id,
"iter_id": self.id,
"worker_id": worker_id,
"worker_idx": sample_idx,
"idx": self.iter_idx,
"iter_idx": self.iter_idx,
"global_idx": outerself._sample_idx,
},
**({} if keys is None else {"keys": keys}),
}
)
outerself._sample_idx += 1
self.iter_idx += 1
return sample
except StopIteration:
self.next_worker_id = 0
if outerself.worker_config.should_log(level=1):
outerself.worker_config.worker_log(
{
"t": "BasicDataLoader.StopIteration",
"r": outerself.worker_config.rank,
"w": None,
"id": outerself.id,
"iter_id": self.id,
}
)
raise
return InnerIterator(super().__iter__())
def config(self):
"""Get the configuration, which defines the dataset. Useful in conjunction with `save_state`
and `restore_state` to match the configuration as well."""
return {
"type": type(self).__qualname__,
"num_workers": self.num_workers,
"persistent_workers": self.persistent_workers,
"pin_memory": self.pin_memory,
"prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor,
"dataset": self.dataset.config(),
}
def can_restore_sample(self) -> bool:
return self.dataset.can_restore_sample()
def restore_sample(self, sample_key: Tuple[Union[str, int, tuple], ...]) -> T:
"""Restores a sample from a key. This is useful to debug the dataset."""
return self.dataset.restore_sample(sample_key)
def _sample_str(self, sample):
"""Returns a human readable debug string for a single sample, also uniquely identifying it."""
import dataclasses
import hashlib
if isinstance(sample, torch.Tensor):
return f"Tensor(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.detach().cpu().numpy().tobytes()).hexdigest()!r})"
elif isinstance(sample, np.ndarray):
return f"ndarray(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.tobytes()).hexdigest()!r})"
elif isinstance(sample, (str, bytes)):
if len(sample) > 100:
return f"{sample[:100]!r}..."
return repr(sample)
elif isinstance(sample, (int, float)):
return repr(sample)
elif isinstance(sample, dict):
innerstr = ", ".join(f"{k}={self._sample_str(v)}, " for k, v in sample.items())
return f"dict({innerstr})"
elif isinstance(sample, (list, tuple)):
innerstr = ", ".join(self._sample_str(v) for v in sample)
return f"{type(sample)}({innerstr})"
elif dataclasses.is_dataclass(sample):
innerstr = ", ".join(
f"{field.name}={self._sample_str(getattr(sample, field.name))}, "
for field in dataclasses.fields(sample)
)
return f"{type(sample).__name__}({innerstr})"
else:
return repr(sample)