Source code for megatron.energon.flavors.webdataset.base_webdataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, TypeVar

import numpy as np

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory, SavableDataset
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator
from megatron.energon.flavors.webdataset.sample_loader import (
    RawSampleData,
    WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import FilteredSample, ShardInfo, reraise_exception
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset

T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)

logger = logging.getLogger(__name__)


[docs] class BaseWebdatasetFactory( BaseCoreDatasetFactory[T_sample], WebdatasetPreparator, Sharder, ErrorHandler, Generic[T_sample], ABC, ): """ Base class for all webdataset sample loader factories. Applies proper sharding across workers. """ path: EPath training: bool worker_config: WorkerConfig shards: List[ShardInfo] rank_shards: List[List[Sequence[ShardInfo]]] def __init__( self, path: EPath, *, split_part: str, training: bool, worker_config: WorkerConfig, shuffle_over_epochs: Optional[int] = 1, parallel_shard_iters: Optional[int] = None, max_samples_per_sequence: Optional[int] = None, info_config: str = ".info.yaml", split_config: str = "split.yaml", part_filter: Optional[Callable[[str], bool]] = None, handler: Callable[[Exception, Optional[str]], None] = reraise_exception, ): """ Base factory for the webdataset sample loader. Args: path: Path to the dataset. split_part: Which part to load (e.g. 'train', 'val', 'test'). training: If true, apply shuffling and loop the dataset. worker_config: Configuration for the workers. shuffle_over_epochs: Only effective if training=True. How many epochs to shuffle over if training. If = 1, every sample is seen exactly once per epoch. If > 1, samples (or rather shard slices) are shuffled within this number of epochs (i.e. randomly selected without replacement). If -1, the shards are effectively shuffle over infinite epochs (i.e. shard slices are drawn with replacement). parallel_shard_iters: Number of parallel opened shards per worker, shuffling between. max_samples_per_sequence: Maximum number of samples per sequence (=how many samples will be sequentially iterated). info_config: Config file to use for sample metadata. split_config: Config file to use for shard split definitions. part_filter: (internal) Function for filtering tar files by dict keys handler: Exception handler. Args: (exception, key). """ assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__" wds_meta = WebdatasetMeta.from_config( path=path, split_part=split_part, info_config=info_config, split_config=split_config ) self.path = path self.paths = [path] self.shards = wds_meta.shards self.sample_excludes = wds_meta.sample_excludes self.split_part_files = wds_meta.split_part_files self.training = training self.worker_config = worker_config self.shuffle_over_epochs = shuffle_over_epochs self.parallel_shard_iters = parallel_shard_iters self.max_samples_per_sequence = max_samples_per_sequence self.part_filter = part_filter self.handler = handler def __len__(self) -> int: return sum(shard.count for shard in self.shards)
[docs] def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]: from megatron.energon.flavors.webdataset.itar_reader import ShardInfosITarReader if self.parallel_shard_iters is None: if self.training: # 16 seems to be a good choice since we don't want too many file handles open parallel_shard_iters = 16 else: parallel_shard_iters = 1 else: parallel_shard_iters = self.parallel_shard_iters workers_sample_slice_offsets = self.shard_workers( self.shards, worker_config=self.worker_config, max_samples_per_sequence=self.max_samples_per_sequence, rotation_offset=worker_rotation_offset, ) _print_shard_slices(self.worker_config, self.shards, workers_sample_slice_offsets) itar_reader = ShardInfosITarReader( self.path, self.shards, part_filter=self.part_filter, sample_filter=self.sample_filter, itar_cache_size=parallel_shard_iters, ) dataset = WebdatasetSampleLoaderDataset( join_readers=[itar_reader], workers_sample_slice_offsets=workers_sample_slice_offsets, worker_config=self.worker_config, shuffle_over_epochs=self.shuffle_over_epochs if self.training else None, parallel_slice_iters=parallel_shard_iters, handler=self.sample_error_handler, ) return MapDataset( dataset, self._load_sample_raw, error_handler=self.error_handler, stateless_map_fn=True, map_fn_config=self.config, worker_config=self.worker_config, )
[docs] def sample_filter(self, key: str) -> bool: return key not in self.sample_excludes
def _load_sample_raw(self, raw_sample: RawSampleData) -> T_sample: # Just a wrapper for the inner tuple. Tuple should be of length 1. assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None return self.load_sample(raw_sample.data[0])
[docs] @abstractmethod def load_sample(self, raw_data: FilteredSample) -> T_sample: """Loads the sample from the dataset.""" ...
[docs] def config(self) -> Dict[str, Any]: return dict( type=type(self).__qualname__, training=self.training, _path=str(self.path), shards=[ dict( name=shard.name, count=shard.count, _path=str(shard.path), ) for shard in self.shards ], sample_excludes=list(self.sample_excludes), shuffle_over_epochs=self.shuffle_over_epochs, parallel_shard_iters=self.parallel_shard_iters, max_samples_per_sequence=self.max_samples_per_sequence, )
def __str__(self): return f"{type(self).__name__}(path={self.path})"
def _print_shard_slices( worker_config: WorkerConfig, shards: List[ShardInfo], slice_offsets: Sequence[Sequence[int]] ): shard_starts = np.cumsum([0] + [shard.count for shard in shards]) def shard_range_info(start: int, end: int) -> str: start_shard_idx = np.searchsorted(shard_starts, start, side="right") - 1 end_shard_idx = np.searchsorted(shard_starts, end, side="left") - 1 if start_shard_idx == end_shard_idx: shard = shards[start_shard_idx] if start - shard_starts[start_shard_idx] == 0: start_str = "(start)" else: start_str = "" if end - shard_starts[start_shard_idx] == shard.count: end_str = "(end)" else: end_str = "" return f"{shard.name}[{start - shard_starts[start_shard_idx]}{start_str}, {end - shard_starts[start_shard_idx]}{end_str}]" else: start_shard = shards[start_shard_idx] end_shard = shards[end_shard_idx] if start - shard_starts[start_shard_idx] == 0: start_str = "(start)" else: start_str = "" if end - shard_starts[end_shard_idx] == end_shard.count: end_str = "(end)" else: end_str = "" return f"{start_shard.name}[{start - shard_starts[start_shard_idx]}{start_str},]-{end_shard.name}[,{end - shard_starts[end_shard_idx]}{end_str}]" for worker_idx, sample_slice_offsets in enumerate(slice_offsets): start_idx = sample_slice_offsets[0] end_idx = sample_slice_offsets[-1] if len(sample_slice_offsets) > 6: offset_str = f"{', '.join(str(o) for o in sample_slice_offsets[:3])} ...<{len(sample_slice_offsets) - 6}> {', '.join(str(o) for o in sample_slice_offsets[-3:])}" else: offset_str = ", ".join(str(o) for o in sample_slice_offsets) if len(sample_slice_offsets) > 6: slices_str = ( ", ".join( shard_range_info(start, end) for start, end in zip(sample_slice_offsets[:3], sample_slice_offsets[1:4]) ) + f" ...<{len(sample_slice_offsets) - 6}> " + ", ".join( shard_range_info(start, end) for start, end in zip(sample_slice_offsets[-4:-1], sample_slice_offsets[-3:]) ) ) else: slices_str = ", ".join( shard_range_info(start, end) for start, end in zip(sample_slice_offsets[:-1], sample_slice_offsets[1:]) ) print( f"rank={worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}] in {len(sample_slice_offsets) - 1} slices, " f"sum(count)={end_idx - start_idx}: indexes=[{offset_str}] slices=[{slices_str}]" )