# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, TypeVar
import numpy as np
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory, SavableDataset
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator
from megatron.energon.flavors.webdataset.sample_loader import (
RawSampleData,
WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import FilteredSample, ShardInfo, reraise_exception
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset
T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)
logger = logging.getLogger(__name__)
class BaseWebdatasetFactory(
BaseCoreDatasetFactory[T_sample],
WebdatasetPreparator,
Sharder,
ErrorHandler,
Generic[T_sample],
ABC,
):
"""
Base class for all webdataset sample loader factories. Applies proper sharding across workers.
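
    Typical usage is to subclass this factory and implement ``load_sample``. A minimal sketch
    (illustrative only; ``MyTextSample`` stands in for an assumed user-defined dataclass with a
    ``text`` field, and the ``txt`` part name is an assumption about the stored data)::

        class MyTextWebdatasetFactory(BaseWebdatasetFactory[MyTextSample]):
            __sample_type__ = MyTextSample

            def load_sample(self, raw_data: FilteredSample) -> MyTextSample:
                # raw_data maps part names (tar member extensions, e.g. "txt") to their raw
                # bytes, alongside metadata such as "__key__".
                return MyTextSample(text=raw_data["txt"].decode("utf-8"))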
"""
path: EPath
training: bool
worker_config: WorkerConfig
shards: List[ShardInfo]
rank_shards: List[List[Sequence[ShardInfo]]]
def __init__(
self,
path: EPath,
*,
split_part: str,
training: bool,
worker_config: WorkerConfig,
shuffle_over_epochs: Optional[int] = 1,
parallel_shard_iters: Optional[int] = None,
max_samples_per_sequence: Optional[int] = None,
info_config: str = ".info.yaml",
split_config: str = "split.yaml",
part_filter: Optional[Callable[[str], bool]] = None,
handler: Callable[[Exception, Optional[str]], None] = reraise_exception,
):
"""
Base factory for the webdataset sample loader.
Args:
path: Path to the dataset.
split_part: Which part to load (e.g. 'train', 'val', 'test').
training: If true, apply shuffling and loop the dataset.
worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True.
                How many epochs to shuffle over.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If = -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices
                are drawn with replacement).
            parallel_shard_iters: Number of shards to keep open and iterate in parallel per
                worker, shuffling between them.
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequentially iterated).
info_config: Config file to use for sample metadata.
split_config: Config file to use for shard split definitions.
            part_filter: (internal) Function for filtering the sample parts (dict keys) to load.
handler: Exception handler. Args: (exception, key).
"""
assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__"
wds_meta = WebdatasetMeta.from_config(
path=path, split_part=split_part, info_config=info_config, split_config=split_config
)
self.path = path
self.paths = [path]
self.shards = wds_meta.shards
self.sample_excludes = wds_meta.sample_excludes
self.split_part_files = wds_meta.split_part_files
self.training = training
self.worker_config = worker_config
self.shuffle_over_epochs = shuffle_over_epochs
self.parallel_shard_iters = parallel_shard_iters
self.max_samples_per_sequence = max_samples_per_sequence
self.part_filter = part_filter
self.handler = handler
def __len__(self) -> int:
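        """Returns the total number of samples across all shards of the selected split part."""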
return sum(shard.count for shard in self.shards)
def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
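        """Builds the savable dataset for this rank: slices the shards over the local workers,
        reads raw samples through an ITar reader, and maps them to typed samples via
        ``load_sample``."""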
from megatron.energon.flavors.webdataset.itar_reader import ShardInfosITarReader
if self.parallel_shard_iters is None:
if self.training:
# 16 seems to be a good choice since we don't want too many file handles open
parallel_shard_iters = 16
else:
parallel_shard_iters = 1
else:
parallel_shard_iters = self.parallel_shard_iters
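        # Compute, for each local worker, the global sample offsets delimiting its shard slices.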
workers_sample_slice_offsets = self.shard_workers(
self.shards,
worker_config=self.worker_config,
max_samples_per_sequence=self.max_samples_per_sequence,
rotation_offset=worker_rotation_offset,
)
_print_shard_slices(self.worker_config, self.shards, workers_sample_slice_offsets)
itar_reader = ShardInfosITarReader(
self.path,
self.shards,
part_filter=self.part_filter,
sample_filter=self.sample_filter,
itar_cache_size=parallel_shard_iters,
)
dataset = WebdatasetSampleLoaderDataset(
join_readers=[itar_reader],
workers_sample_slice_offsets=workers_sample_slice_offsets,
worker_config=self.worker_config,
shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
parallel_slice_iters=parallel_shard_iters,
handler=self.sample_error_handler,
)
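        # Map every raw sample dict to a typed sample via the subclass's load_sample.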
return MapDataset(
dataset,
self._load_sample_raw,
error_handler=self.error_handler,
stateless_map_fn=True,
map_fn_config=self.config,
worker_config=self.worker_config,
)
def sample_filter(self, key: str) -> bool:
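        """Returns True if the sample with the given key should be loaded, i.e. it is not
        listed in the split's sample excludes."""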
return key not in self.sample_excludes
def _load_sample_raw(self, raw_sample: RawSampleData) -> T_sample:
        # RawSampleData wraps a tuple of per-reader samples; with a single reader it has length 1.
assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None
return self.load_sample(raw_sample.data[0])
@abstractmethod
def load_sample(self, raw_data: FilteredSample) -> T_sample:
"""Loads the sample from the dataset."""
...
def config(self) -> Dict[str, Any]:
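        """Returns a dictionary describing this dataset's configuration (type, path, shards and
        shuffling/loading settings)."""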
return dict(
type=type(self).__qualname__,
training=self.training,
_path=str(self.path),
shards=[
dict(
name=shard.name,
count=shard.count,
_path=str(shard.path),
)
for shard in self.shards
],
sample_excludes=list(self.sample_excludes),
shuffle_over_epochs=self.shuffle_over_epochs,
parallel_shard_iters=self.parallel_shard_iters,
max_samples_per_sequence=self.max_samples_per_sequence,
)
def __str__(self):
return f"{type(self).__name__}(path={self.path})"
def _print_shard_slices(
worker_config: WorkerConfig, shards: List[ShardInfo], slice_offsets: Sequence[Sequence[int]]
):
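    """Prints, for each local worker, its assigned sample slice offsets and the corresponding
    shard-relative ranges. Illustrative output (assuming worker 0 gets two shards of 100
    samples each):

        rank=0, worker=0: sample_range=[0, 200] in 2 slices, sum(count)=200: indexes=[0, 100, 200] slices=[shard_000.tar[0(start), 100(end)], shard_001.tar[0(start), 100(end)]]
    """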
shard_starts = np.cumsum([0] + [shard.count for shard in shards])
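    # Helper: format a global sample offset range as shard-relative positions, marking exact
    # shard starts and ends.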
    def shard_range_info(start: int, end: int) -> str:
        start_shard_idx = np.searchsorted(shard_starts, start, side="right") - 1
        end_shard_idx = np.searchsorted(shard_starts, end, side="left") - 1
        start_shard = shards[start_shard_idx]
        end_shard = shards[end_shard_idx]
        rel_start = start - shard_starts[start_shard_idx]
        rel_end = end - shard_starts[end_shard_idx]
        start_str = "(start)" if rel_start == 0 else ""
        end_str = "(end)" if rel_end == end_shard.count else ""
        if start_shard_idx == end_shard_idx:
            return f"{start_shard.name}[{rel_start}{start_str}, {rel_end}{end_str}]"
        else:
            return (
                f"{start_shard.name}[{rel_start}{start_str},]"
                f"-{end_shard.name}[,{rel_end}{end_str}]"
            )
for worker_idx, sample_slice_offsets in enumerate(slice_offsets):
start_idx = sample_slice_offsets[0]
end_idx = sample_slice_offsets[-1]
        if len(sample_slice_offsets) > 6:
            offset_str = (
                ", ".join(str(o) for o in sample_slice_offsets[:3])
                + f" ...<{len(sample_slice_offsets) - 6}> "
                + ", ".join(str(o) for o in sample_slice_offsets[-3:])
            )
            slices_str = (
                ", ".join(
                    shard_range_info(start, end)
                    for start, end in zip(sample_slice_offsets[:3], sample_slice_offsets[1:4])
                )
                + f" ...<{len(sample_slice_offsets) - 6}> "
                + ", ".join(
                    shard_range_info(start, end)
                    for start, end in zip(sample_slice_offsets[-4:-1], sample_slice_offsets[-3:])
                )
            )
        else:
            offset_str = ", ".join(str(o) for o in sample_slice_offsets)
            slices_str = ", ".join(
                shard_range_info(start, end)
                for start, end in zip(sample_slice_offsets[:-1], sample_slice_offsets[1:])
            )
print(
f"rank={worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}] in {len(sample_slice_offsets) - 1} slices, "
f"sum(count)={end_idx - start_idx}: indexes=[{offset_str}] slices=[{slices_str}]"
)