# megatron.energon.flavors.jsonl.crude_jsonl_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import json
import logging
from typing import Any, Callable, Dict, Optional

from megatron.energon.cache import FileStore
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
    BaseCoreDatasetFactory,
    SavableDataset,
    legacy_handler,
)
from megatron.energon.flavors.crude import CrudeSample
from megatron.energon.flavors.jsonl.ijsonl import IJsonlIndexReader
from megatron.energon.flavors.jsonl.jsonl_prepare import JsonlPreparator
from megatron.energon.flavors.webdataset.base_webdataset import _print_shard_slices
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.sample_loader import (
    RawSampleData,
    WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import (
    DatasetSubset,
    FilteredSample,
    ShardInfo,
    reraise_exception,
)
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset

logger = logging.getLogger(__name__)


class CrudeJsonlDatasetFactory(
    BaseCoreDatasetFactory[CrudeSample],
    JsonlPreparator,
    Sharder,
    ErrorHandler,
):
    """
    Factory class for creating a crude dataset from JSONL (JSON Lines) files.

    This factory creates datasets from JSONL files where each line contains a JSON
    object. The samples are returned as CrudeSample objects (dictionary-like)
    containing the raw JSON data.
    """

    __sample_type__ = CrudeSample

    path: EPath
    training: bool
    worker_config: WorkerConfig

    def __init__(
        self,
        path: EPath,
        *,
        training: bool,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = 1,
        parallel_shard_iters: Optional[int] = None,
        max_samples_per_sequence: Optional[int] = None,
        subset: Optional[DatasetSubset] = None,
        part_filter: Optional[Callable[[str], bool]] = None,
        handler: Callable[
            [Exception, Optional[str], Optional[list[SourceInfo]]], None
        ] = reraise_exception,
    ):
        """
        Factory for a jsonl file as a crude dataset.

        Args:
            path: Path to the jsonl file.
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True. Number of epochs to
                shuffle over. If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number
                of epochs (i.e. randomly selected without replacement). If -1, the shards
                are effectively shuffled over infinitely many epochs (i.e. shard slices
                are drawn with replacement).
            parallel_shard_iters: Number of shards opened in parallel per worker,
                shuffling between them.
            max_samples_per_sequence: Maximum number of samples per sequence (i.e. how
                many samples will be iterated sequentially).
            subset: If specified, the dataset will be subsetted.
            part_filter: (internal) Function for filtering sample parts by key.
            handler: Exception handler. Args: (exception, key, source_infos).
        """
        assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__"
        self.path = path
        self.paths = [path]
        self.training = training
        self.worker_config = worker_config
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_shard_iters = parallel_shard_iters
        self.max_samples_per_sequence = max_samples_per_sequence
        self.subset = subset
        self.part_filter = part_filter
        self.handler = legacy_handler(handler)
        if part_filter is None or part_filter("json"):
            self._len = IJsonlIndexReader.count_samples(path)
        else:
            self._len = 0
        assert self.path.size() == IJsonlIndexReader.size(path), (
            "The index of the jsonl file does not match the file. Regenerate the index."
        )

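    # Usage sketch (illustrative only; the file path and worker config below are
    # assumptions, and in practice this factory is usually instantiated by
    # energon's dataset loading machinery rather than by hand):
    #
    #   worker_config = WorkerConfig.default_worker_config()  # assumed helper
    #   factory = CrudeJsonlDatasetFactory(
    #       EPath("/data/samples.jsonl"),  # hypothetical path; a matching index must exist
    #       training=True,
    #       worker_config=worker_config,
    #       shuffle_over_epochs=2,  # shuffle shard slices across two epochs
    #   )
    #   print(len(factory))  # sample count, read from the jsonl index
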
    def __len__(self) -> int:
        return self._len

    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[CrudeSample]:
        from megatron.energon.flavors.jsonl.ijsonl_reader import IJsonlReader

        if self.parallel_shard_iters is None:
            if self.training:
                # 16 seems to be a good choice, since we don't want too many file handles open
                parallel_shard_iters = 16
            else:
                parallel_shard_iters = 1
        else:
            parallel_shard_iters = self.parallel_shard_iters

        virtual_shards = [
            ShardInfo(
                name=self.path.name,
                path=self.path,
                count=self._len,
            )
        ]

        workers_sample_slice_offsets = self.shard_workers(
            virtual_shards,
            worker_config=self.worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            rotation_offset=worker_rotation_offset,
            subset=self.subset,
        )

        _print_shard_slices(self.worker_config, virtual_shards, workers_sample_slice_offsets)

        itar_reader = IJsonlReader(
            self.path,
            index_cache_size=parallel_shard_iters,
        )

        dataset = WebdatasetSampleLoaderDataset(
            join_readers=[itar_reader],
            workers_sample_slice_offsets=workers_sample_slice_offsets,
            worker_config=self.worker_config,
            shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
            parallel_slice_iters=parallel_shard_iters,
        )
        return MapDataset(
            dataset,
            self._load_sample_raw,
            error_handler=self.error_handler,
            stateless_map_fn=True,
            map_fn_config=self.config,
            worker_config=self.worker_config,
        )

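    # Sketch of consuming the built dataset (the loop below is an assumption, not
    # part of this module). Note that this base factory does not decode the "json"
    # field; it remains the raw line payload:
    #
    #   dataset = factory.build()
    #   for sample in dataset:
    #       assert isinstance(sample, CrudeSample)  # dict-like raw sample
    #       raw = sample["json"]  # undecoded JSON line
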
    def as_file_store(self) -> "FileStore":
        from megatron.energon.cache.file_store import JsonlFileStore

        return JsonlFileStore(self.path)

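    # Sketch: the same jsonl file exposed as a random-access FileStore (how the
    # store is queried depends on the FileStore interface and is not shown here):
    #
    #   store = factory.as_file_store()
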
    def _load_sample(self, sample: FilteredSample) -> CrudeSample:
        return CrudeSample(sample)

    def _load_sample_raw(self, raw_sample: RawSampleData) -> CrudeSample:
        # Just a wrapper for the inner tuple. Tuple should be of length 1.
        assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None
        return self._load_sample(raw_sample.data[0])

    def config(self) -> Dict[str, Any]:
        return dict(
            type=type(self).__qualname__,
            training=self.training,
            _path=str(self.path),
            jsonl_filename=self.path.name,
            count=self._len,
            shuffle_over_epochs=self.shuffle_over_epochs,
            parallel_shard_iters=self.parallel_shard_iters,
            max_samples_per_sequence=self.max_samples_per_sequence,
            subset=self.subset.config() if self.subset is not None else None,
        )

    def __str__(self):
        return f"{type(self).__name__}(path={self.path})"


class DefaultCrudeJsonlDatasetFactory(CrudeJsonlDatasetFactory):
    """
    Adds subflavors to the sample and loads the json.
    """

    def __init__(self, path: EPath, *, subflavors: Optional[Dict[str, Any]] = None, **kwargs):
        # A decoder does not apply to jsonl data; drop it if passed through.
        if "decoder" in kwargs:
            del kwargs["decoder"]
        super().__init__(path, **kwargs)
        self.subflavors = subflavors

    def _load_sample(self, sample: FilteredSample) -> CrudeSample:
        sample["__subflavors__"] = self.subflavors
        # Instead of using a decoder, we just load the json here, as we know it's json.
        sample["json"] = json.loads(sample["json"])
        return super()._load_sample(sample)

    def config(self) -> Dict[str, Any]:
        return dict(
            **super().config(),
            subflavors=self.subflavors,
        )

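# End-to-end sketch for the default factory (the path, worker config, and
# subflavors below are assumptions; this comment block is illustrative only):
#
#   factory = DefaultCrudeJsonlDatasetFactory(
#       EPath("/data/captions.jsonl"),  # hypothetical jsonl file with prebuilt index
#       training=False,
#       worker_config=worker_config,
#       subflavors={"source": "captions"},
#   )
#   for sample in factory.build():
#       sample["__subflavors__"]  # -> {"source": "captions"}
#       sample["json"]            # -> dict parsed via json.loads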