Source code for megatron.energon.flavors.webdataset.base_webdataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, TypeVar

import numpy as np

from megatron.energon.cache import FileStore
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
    BaseCoreDatasetFactory,
    SavableDataset,
    legacy_handler,
)
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator
from megatron.energon.flavors.webdataset.sample_loader import (
    RawSampleData,
    WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import FilteredSample, ShardInfo, reraise_exception
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset

T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)

logger = logging.getLogger(__name__)


class BaseWebdatasetFactory(
    BaseCoreDatasetFactory[T_sample],
    WebdatasetPreparator,
    Sharder,
    ErrorHandler,
    Generic[T_sample],
    ABC,
):
    """
    Base class for all webdataset sample loader factories. Applies proper sharding across workers.
    """

    path: EPath
    training: bool
    worker_config: WorkerConfig
    shards: List[ShardInfo]
    rank_shards: List[List[Sequence[ShardInfo]]]

    def __init__(
        self,
        path: EPath,
        *,
        split_part: str,
        training: bool,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = 1,
        parallel_shard_iters: Optional[int] = None,
        max_samples_per_sequence: Optional[int] = None,
        split_config: str = "split.yaml",
        part_filter: Optional[Callable[[str], bool]] = None,
        handler: Callable[
            [Exception, Optional[str], Optional[list[SourceInfo]]], None
        ] = reraise_exception,
    ):
        """
        Base factory for the webdataset sample loader.

        Args:
            path: Path to the dataset.
            split_part: Which part to load (e.g. 'train', 'val', 'test').
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True. How many epochs to shuffle over.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If -1, the shards are effectively shuffled over an infinite number of epochs
                (i.e. shard slices are drawn with replacement).
            parallel_shard_iters: Number of shards opened in parallel per worker, which are
                shuffled between.
            max_samples_per_sequence: Maximum number of samples per sequence (i.e. how many samples
                will be iterated sequentially).
            split_config: Config file to use for shard split definitions.
            part_filter: (internal) Function for filtering tar files by dict keys.
            handler: Exception handler. Args: (exception, key, source_infos).
        """
        assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__"
        wds_meta = WebdatasetMeta.from_config(
            path=path, split_part=split_part, split_config=split_config
        )
        self.path = path
        self.paths = [path]
        self.shards = wds_meta.shards
        self.sample_excludes = wds_meta.sample_excludes
        self.split_part_files = wds_meta.split_part_files
        self.training = training
        self.worker_config = worker_config
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_shard_iters = parallel_shard_iters
        self.max_samples_per_sequence = max_samples_per_sequence
        self.part_filter = part_filter
        self.handler = legacy_handler(handler)
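
    # Illustrative reading of `shuffle_over_epochs` (worked example, not code from this
    # module): with shuffle_over_epochs=1, each shard slice is visited exactly once per
    # epoch; with shuffle_over_epochs=2, the slices of two epochs are pooled and drawn
    # without replacement, so one slice may repeat before another appears; with -1,
    # slices are drawn with replacement indefinitely.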

    def __len__(self) -> int:
        return sum(shard.count for shard in self.shards)

    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
        from megatron.energon.flavors.webdataset.itar_reader import ShardInfosITarReader

        if self.parallel_shard_iters is None:
            if self.training:
                # 16 seems to be a good choice since we don't want too many file handles open
                parallel_shard_iters = 16
            else:
                parallel_shard_iters = 1
        else:
            parallel_shard_iters = self.parallel_shard_iters

        # Compute the per-worker sample slice offsets across all shards
        workers_sample_slice_offsets = self.shard_workers(
            self.shards,
            worker_config=self.worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            rotation_offset=worker_rotation_offset,
        )
        _print_shard_slices(self.worker_config, self.shards, workers_sample_slice_offsets)

        # Reader for accessing the samples within the tar shards
        itar_reader = ShardInfosITarReader(
            self.path,
            self.shards,
            part_filter=self.part_filter,
            sample_filter=self.sample_filter,
            itar_cache_size=parallel_shard_iters,
        )

        dataset = WebdatasetSampleLoaderDataset(
            join_readers=[itar_reader],
            workers_sample_slice_offsets=workers_sample_slice_offsets,
            worker_config=self.worker_config,
            shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
            parallel_slice_iters=parallel_shard_iters,
        )
        # Map the raw samples to typed samples via `load_sample`
        return MapDataset(
            dataset,
            self._load_sample_raw,
            error_handler=self.error_handler,
            stateless_map_fn=True,
            map_fn_config=self.config,
            worker_config=self.worker_config,
        )

    def as_file_store(self) -> "FileStore":
        from megatron.energon.cache.file_store import WebdatasetFileStore

        return WebdatasetFileStore(self.path)

    def sample_filter(self, key: str) -> bool:
        return key not in self.sample_excludes

    def _load_sample_raw(self, raw_sample: RawSampleData) -> T_sample:
        # Just a wrapper for the inner tuple. Tuple should be of length 1.
        assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None
        return self.load_sample(raw_sample.data[0])

    @abstractmethod
    def load_sample(self, raw_data: FilteredSample) -> T_sample:
        """Loads the sample from the dataset."""
        ...
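
    # Illustrative sketch of a subclass implementation (the part name "txt" and the
    # `TextSample` fields are assumptions for illustration, not defined in this class):
    #
    #     def load_sample(self, raw_data: FilteredSample) -> TextSample:
    #         return TextSample(
    #             __key__=raw_data["__key__"],
    #             __restore_key__=raw_data["__restore_key__"],
    #             text=raw_data["txt"].decode("utf-8"),
    #         )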

    def config(self) -> Dict[str, Any]:
        return dict(
            type=type(self).__qualname__,
            training=self.training,
            _path=str(self.path),
            shards=[
                dict(
                    name=shard.name,
                    count=shard.count,
                    _path=str(shard.path),
                )
                for shard in self.shards
            ],
            sample_excludes=list(self.sample_excludes),
            shuffle_over_epochs=self.shuffle_over_epochs,
            parallel_shard_iters=self.parallel_shard_iters,
            max_samples_per_sequence=self.max_samples_per_sequence,
        )

    def __str__(self):
        return f"{type(self).__name__}(path={self.path})"
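
# Illustrative usage sketch (not part of this module): a concrete subclass that
# defines `__sample_type__` and `load_sample` could be instantiated and built roughly
# as follows. `MyWebdatasetFactory` is a hypothetical name, and the
# `WorkerConfig.default_worker_config()` helper is assumed here for brevity.
#
#     factory = MyWebdatasetFactory(
#         EPath("/data/my_dataset"),
#         split_part="train",
#         training=True,
#         worker_config=WorkerConfig.default_worker_config(),
#         shuffle_over_epochs=1,
#     )
#     dataset = factory.build()  # SavableDataset yielding __sample_type__ instances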

def _print_shard_slices(
    worker_config: WorkerConfig, shards: List[ShardInfo], slice_offsets: Sequence[Sequence[int]]
):
    # Cumulative sample counts, i.e. the global starting offset of each shard
    shard_starts = np.cumsum([0] + [shard.count for shard in shards])

    def shard_range_info(start: int, end: int) -> str:
        # Render a global [start, end) sample range as shard-relative indices for display
        start_shard_idx = np.searchsorted(shard_starts, start, side="right") - 1
        end_shard_idx = np.searchsorted(shard_starts, end, side="left") - 1
        if start_shard_idx == end_shard_idx:
            shard = shards[start_shard_idx]
            if start - shard_starts[start_shard_idx] == 0:
                start_str = "(start)"
            else:
                start_str = ""
            if end - shard_starts[start_shard_idx] == shard.count:
                end_str = "(end)"
            else:
                end_str = ""
            return f"{shard.name}[{start - shard_starts[start_shard_idx]}{start_str}, {end - shard_starts[start_shard_idx]}{end_str}]"
        else:
            start_shard = shards[start_shard_idx]
            end_shard = shards[end_shard_idx]
            if start - shard_starts[start_shard_idx] == 0:
                start_str = "(start)"
            else:
                start_str = ""
            if end - shard_starts[end_shard_idx] == end_shard.count:
                end_str = "(end)"
            else:
                end_str = ""
            return f"{start_shard.name}[{start - shard_starts[start_shard_idx]}{start_str},]-{end_shard.name}[,{end - shard_starts[end_shard_idx]}{end_str}]"

    for worker_idx, sample_slice_offsets in enumerate(slice_offsets):
        start_idx = sample_slice_offsets[0]
        end_idx = sample_slice_offsets[-1]
        # Abbreviate long offset/slice lists to the first and last three entries
        if len(sample_slice_offsets) > 6:
            offset_str = f"{', '.join(str(o) for o in sample_slice_offsets[:3])} ...<{len(sample_slice_offsets) - 6}> {', '.join(str(o) for o in sample_slice_offsets[-3:])}"
        else:
            offset_str = ", ".join(str(o) for o in sample_slice_offsets)
        if len(sample_slice_offsets) > 6:
            slices_str = (
                ", ".join(
                    shard_range_info(start, end)
                    for start, end in zip(sample_slice_offsets[:3], sample_slice_offsets[1:4])
                )
                + f" ...<{len(sample_slice_offsets) - 6}> "
                + ", ".join(
                    shard_range_info(start, end)
                    for start, end in zip(sample_slice_offsets[-4:-1], sample_slice_offsets[-3:])
                )
            )
        else:
            slices_str = ", ".join(
                shard_range_info(start, end)
                for start, end in zip(sample_slice_offsets[:-1], sample_slice_offsets[1:])
            )
        print(
            f"rank={worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}] in {len(sample_slice_offsets) - 1} slices, "
            f"sum(count)={end_idx - start_idx}: indexes=[{offset_str}] slices=[{slices_str}]"
        )
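
# Example of the diagnostics printed above (illustrative values, wrapped here for
# readability; the actual output is one line per worker of the current rank). A single
# shard "shard_000.tar" with 1000 samples, split into two slices for worker 0, would be
# reported as:
#
#   rank=0, worker=0: sample_range=[0, 1000] in 2 slices, sum(count)=1000:
#       indexes=[0, 500, 1000] slices=[shard_000.tar[0(start), 500], shard_000.tar[500, 1000(end)]]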