Source code for megatron.energon.flavors.webdataset.joined_webdataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Mapping,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
    BaseCoreDatasetFactory,
    Sample,
    SavableDataset,
    set_sample_restore_key,
)
from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.indexing import JoinIndexReader
from megatron.energon.flavors.webdataset.itar_reader import JoinIndexFileITarReader
from megatron.energon.flavors.webdataset.sample_loader import (
    RawSampleData,
    WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import ShardInfo, reraise_exception
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset

T_sample = TypeVar("T_sample", covariant=True)


class JoinedWebdatasetFactory(
    BaseCoreDatasetFactory[T_sample], Sharder, ErrorHandler[T_sample], Generic[T_sample], ABC
):
    """
    Base class for all webdataset loaders. Applies proper sharding across workers.
    Can join multiple datasets.
    """

    training: bool
    worker_config: WorkerConfig
    shuffle_over_epochs: Optional[int] = 1
    parallel_shard_iters: Optional[int]
    max_samples_per_sequence: Optional[int]
    join_index: EPath
    handler: Callable[[Exception, Optional[str]], None]

    shards: List[Sequence[ShardInfo]]
    part_datasets: SavableDataset[T_sample]

    inner_datasets: List[BaseWebdatasetFactory]
    inner_dataset_keys: Optional[List[str]]
    _sample_joiner: Callable[..., T_sample]

    def __init__(
        self,
        inner_datasets: Union[
            Sequence[BaseWebdatasetFactory], Mapping[str, BaseWebdatasetFactory]
        ],
        *,
        training: bool,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = 1,
        parallel_shard_iters: Optional[int] = None,
        max_samples_per_sequence: Optional[int] = None,
        join_index: EPath,
        joiner: Union[Type[T_sample], Callable[..., T_sample]],
        handler: Callable[[Exception, Optional[str]], None] = reraise_exception,
    ):
        """
        Constructs the loader for a joined webdataset. The samples from the inner datasets are
        joined into a single sample using the joiner function.

        Args:
            inner_datasets: The inner datasets. Must be loaded internally with `_is_composed=True`.
                Either a list (*args for the joiner) or a dict (**kwargs for the joiner) of
                datasets, whose samples will be passed to the joiner function as *args or
                **kwargs respectively.
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True. How many epochs to shuffle over.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If -1, the shards are effectively shuffled over infinite epochs (i.e. shard
                slices are drawn with replacement).
            parallel_shard_iters: Number of shards opened in parallel per worker, shuffling
                between them.
            max_samples_per_sequence: Maximum number of samples per sequence (i.e. how many
                samples will be iterated sequentially).
            join_index: Path to the join index file. Only required for join_method="left".
            joiner: Type of the joined samples or a method for joining the samples.
            handler: Exception handler. Args: (exception, key).
        """
        self.__sample_type__ = joiner
        assert all(not hasattr(d, "dataset") for d in inner_datasets), (
            "Inner dataset was not instantiated with _is_composed=True"
        )
        if isinstance(joiner, type) and issubclass(joiner, Sample):
            joiner = joiner.from_joined
        else:
            assert callable(joiner), f"Joiner {joiner} must be a callable or a Sample subclass"
        if isinstance(inner_datasets, Mapping):
            inner_keys = list(inner_datasets.keys())
            self.inner_dataset_keys = inner_keys
            # Wrap the joiner to pass the samples as kwargs
            self._sample_joiner = lambda *samples: joiner(**dict(zip(inner_keys, samples)))
            inner_datasets = list(inner_datasets.values())
        else:
            assert isinstance(inner_datasets, Sequence)
            self._sample_joiner = joiner
            self.inner_dataset_keys = None

        self.join_index = join_index
        self.inner_datasets = inner_datasets
        self.shards = list(zip(*(dataset.shards for dataset in self.inner_datasets)))
        self.training = training
        self.worker_config = worker_config
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_shard_iters = parallel_shard_iters
        self.max_samples_per_sequence = max_samples_per_sequence
        self.handler = handler

    def __len__(self) -> int:
        return sum(shard.count for shard in self.inner_datasets[0].shards)
    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
        if self.parallel_shard_iters is None:
            if self.training:
                # 16 seems to be a good choice since we don't want too many file handles open
                parallel_shard_iters = 16
            else:
                parallel_shard_iters = 1
        else:
            parallel_shard_iters = self.parallel_shard_iters

        # Get join index, get size, distribute samples
        # Get samples for each worker on current rank

        assert self.join_index.is_file(), (
            f"Join index {self.join_index} does not exist, did you prepare the metadataset? "
            "If you already prepared the metadataset, the join index might be outdated due to "
            "modifications to the inner datasets. In this case, you need to re-prepare the "
            "metadataset."
        )

        with JoinIndexReader(self.join_index) as jir:
            total_samples = len(jir)

        workers_sample_slice_offsets = self.slice_workers(
            total_samples,
            worker_config=self.worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            rotation_offset=worker_rotation_offset,
        )

        for worker_idx, sample_slice_offsets in enumerate(workers_sample_slice_offsets):
            start_idx = sample_slice_offsets[0]
            end_idx = sample_slice_offsets[-1]
            if len(sample_slice_offsets) > 6:
                offset_str = (
                    f"{', '.join(str(o) for o in sample_slice_offsets[:3])} "
                    f"...<{len(sample_slice_offsets) - 6}> "
                    f"{', '.join(str(o) for o in sample_slice_offsets[-3:])}"
                )
            else:
                offset_str = ", ".join(str(o) for o in sample_slice_offsets)
            print(
                f"rank={self.worker_config.rank}, worker={worker_idx}: "
                f"sample_range=[{start_idx}, {end_idx}) in {len(sample_slice_offsets) - 1} slices, "
                f"sum(count)={end_idx - start_idx}: [{offset_str}]"
            )

        itar_readers = [
            JoinIndexFileITarReader(
                index_file=self.join_index,
                column=col_idx,
                tar_filenames=indexed_dataset.split_part_files,
                base_path=indexed_dataset.path,
                part_filter=indexed_dataset.part_filter,
                itar_cache_size=parallel_shard_iters,
            )
            for col_idx, indexed_dataset in enumerate(self.inner_datasets)
        ]

        dataset = WebdatasetSampleLoaderDataset(
            join_readers=itar_readers,
            workers_sample_slice_offsets=workers_sample_slice_offsets,
            worker_config=self.worker_config,
            shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
            parallel_slice_iters=parallel_shard_iters,
            handler=self.sample_error_handler,
        )
        return self._process_samples(dataset)
    @property
    def paths(self) -> List[EPath]:
        return [dataset.path for dataset in self.inner_datasets]

    def _process_samples(self, dataset: SavableDataset[RawSampleData]) -> SavableDataset[T_sample]:
        """Internally loads the sample."""
        return MapDataset(
            dataset,
            self.load_sample,
            error_handler=self.error_handler,
            stateless_map_fn=True,
            map_fn_config=self.config,
            worker_config=self.worker_config,
        )
    def load_sample(self, samples: RawSampleData) -> T_sample:
        assert len(samples.data) > 0 and samples.data[0] is not None, "Always need primary sample"

        # First call the loaders of all inner datasets
        loaded_samples = tuple(
            None if sample is None else dataset.load_sample(sample)
            for dataset, sample in zip(self.inner_datasets, samples.data)
        )

        # Then combine the loaded samples into the final type
        return set_sample_restore_key(
            self._sample_joiner(*loaded_samples),
            *samples.__restore_key__,
            src=self,
            fail_otherwise=True,
        )
    def config(self) -> Dict[str, Any]:
        return dict(
            type=type(self).__qualname__,
            joined_datasets=[dataset.config() for dataset in self.inner_datasets],
            training=self.training,
            shuffle_over_epochs=self.shuffle_over_epochs,
            parallel_shard_iters=self.parallel_shard_iters,
            max_samples_per_sequence=self.max_samples_per_sequence,
        )
    def __str__(self):
        return f"{type(self).__name__}(paths={self.paths})"
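
The sketch below illustrates how this factory is intended to be used, following the constructor documentation above. It is only a usage sketch, not part of the module: `ds_image` and `ds_text` (two `BaseWebdatasetFactory` instances loaded with `_is_composed=True`), `worker_config` (an existing `WorkerConfig`), the `ImageCaptionPair` type, the joiner function, and the index path are hypothetical stand-ins, and the join index is assumed to have already been prepared.

# Usage sketch (hypothetical names and paths; not part of this module).
from dataclasses import dataclass

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.joined_webdataset import JoinedWebdatasetFactory


@dataclass
class ImageCaptionPair:
    # Hypothetical combined sample type holding one loaded sample from each inner dataset.
    image: object
    caption: object


def join_image_caption(image, caption):
    # Hypothetical joiner: because `inner_datasets` is a dict below, it is called with the
    # loaded sample of each inner dataset as keyword arguments matching the dict keys.
    return ImageCaptionPair(image=image, caption=caption)


# `ds_image`, `ds_text` and `worker_config` are assumed to exist already.
factory = JoinedWebdatasetFactory(
    inner_datasets={"image": ds_image, "caption": ds_text},
    training=True,
    worker_config=worker_config,
    join_index=EPath("/data/metadataset/join_index.bin"),
    joiner=join_image_caption,
)
dataset = factory.build()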