megatron.energon

class megatron.energon.AugmentTaskEncoder(task_encoder)[source]

Bases: TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch], Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]

Augment a task encoder with additional functionality. By default, delegates everything to the original task encoder.

Parameters:

task_encoder (TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch])

__init__(task_encoder)[source]

Initialize the augmenting task encoder.

Parameters:

task_encoder (TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch]) – The delegate task encoder. All calls will by default be forwarded to this.
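
The delegation pattern described above can be sketched in plain Python. This is a minimal illustration, not energon's implementation. `BaseEncoder`, `DelegatingEncoder` and `UppercaseAugment` are hypothetical stand-ins for a TaskEncoder, an AugmentTaskEncoder, and a user subclass that overrides a single method.

```python
from typing import Any, List


class BaseEncoder:
    """Hypothetical stand-in for a task encoder."""

    def encode_sample(self, sample: str) -> str:
        return sample.strip()

    def batch(self, samples: List[str]) -> List[str]:
        return list(samples)


class DelegatingEncoder:
    """Forwards each call to the wrapped encoder unless a subclass overrides it."""

    def __init__(self, task_encoder: Any) -> None:
        self.task_encoder = task_encoder

    def encode_sample(self, sample: Any) -> Any:
        return self.task_encoder.encode_sample(sample)

    def batch(self, samples: List[Any]) -> Any:
        return self.task_encoder.batch(samples)


class UppercaseAugment(DelegatingEncoder):
    """Overrides only encode_sample(); batch() is still delegated."""

    def encode_sample(self, sample: Any) -> Any:
        # Run the wrapped encoder first, then add the extra behavior on top.
        return super().encode_sample(sample).upper()


enc = UppercaseAugment(BaseEncoder())
print(enc.batch([enc.encode_sample("  cat "), enc.encode_sample("dog")]))  # → ['CAT', 'DOG']
```

Only the overridden method changes behavior. Everything else falls through to the wrapped encoder.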

batch(samples)[source]

Batch a list of encoded samples into a raw batch. May raise megatron.energon.SkipSample to skip a batch.

Parameters:

samples (List[T_encoded_sample])

Return type:

T_raw_batch

property decoder: SampleDecoder

The default decoder for webdataset samples.

encode_batch(batch_data)[source]

Encode a batch of samples. May raise megatron.energon.SkipSample to skip a batch. Alternatively, this can be a generator that yields (or ignores) new batches.

Parameters:

batch_data (T_raw_batch)

Return type:

T_batch

encode_sample(sample)[source]

Encode a single sample. May raise megatron.energon.SkipSample to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples. If this is defined, preencode_sample() and postencode_sample() must not be defined.

Parameters:

sample (T_sample)

Return type:

T_encoded_sample
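
The two forms of encode_sample() can be illustrated with a small driver: it either returns one sample or acts as a generator that yields zero or more samples. `run_encoder`, `SkipSampleStandIn` and the two encoders are hypothetical. The sketch only shows how a consumer might flatten both forms and is not energon's pipeline code.

```python
import inspect
from typing import Any, Callable, Iterable, Iterator


class SkipSampleStandIn(Exception):
    """Hypothetical stand-in for megatron.energon.SkipSample."""


def run_encoder(encode: Callable[[Any], Any], samples: Iterable[Any]) -> Iterator[Any]:
    for sample in samples:
        try:
            result = encode(sample)
            if inspect.isgenerator(result):
                # Generator form: may yield several samples, or none at all.
                outputs = list(result)
            else:
                outputs = [result]
        except SkipSampleStandIn:
            continue  # The encoder asked to drop this sample.
        yield from outputs


def split_words(text: str):
    # Generator form: one output per word; an empty string yields nothing.
    yield from text.split()


def reject_short(text: str) -> str:
    if len(text) < 3:
        raise SkipSampleStandIn()
    return text.upper()


print(list(run_encoder(split_words, ["a b", "", "c"])))  # → ['a', 'b', 'c']
print(list(run_encoder(reject_short, ["hi", "cat"])))  # → ['CAT']
```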

class megatron.energon.BaseCoreDatasetFactory[source]

Bases: Generic[T_sample], ABC

Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for joining in a joined dataset.

abstractmethod as_file_store()[source]

Returns the dataset as a random access dataset.

Return type:

FileStore

abstractmethod build(worker_rotation_offset=0)[source]

Builds the dataset.

Parameters:

worker_rotation_offset (int)

Return type:

SavableDataset[T_sample]

paths: List[EPath]
subflavors: Dict[str, Any]
class megatron.energon.BaseWebdatasetFactory(path, *, split_part, training, worker_config, shuffle_over_epochs=1, parallel_shard_iters=None, max_samples_per_sequence=None, split_config='split.yaml', part_filter=None, handler=<function reraise_exception>)[source]

Bases: BaseCoreDatasetFactory[T_sample], WebdatasetPreparator, Sharder, ErrorHandler, Generic[T_sample], ABC

Base class for all webdataset sample loader factories. Applies proper sharding across workers.

Parameters:
  • path (EPath)

  • split_part (str)

  • training (bool)

  • worker_config (WorkerConfig)

  • shuffle_over_epochs (int | None)

  • parallel_shard_iters (int | None)

  • max_samples_per_sequence (int | None)

  • split_config (str)

  • part_filter (Callable[[str], bool] | None)

  • handler (Callable[[Exception, str | None, list[SourceInfo] | None], None])

__init__(path, *, split_part, training, worker_config, shuffle_over_epochs=1, parallel_shard_iters=None, max_samples_per_sequence=None, split_config='split.yaml', part_filter=None, handler=<function reraise_exception>)[source]

Base factory for the webdataset sample loader.

Parameters:
  • path (EPath) – Path to the dataset.

  • split_part (str) – Which part to load (e.g. ‘train’, ‘val’, ‘test’).

  • training (bool) – If true, apply shuffling and loop the dataset.

  • worker_config (WorkerConfig) – Configuration for the workers.

  • shuffle_over_epochs (int | None) – Only effective if training=True. How many epochs to shuffle over. If 1, every sample is seen exactly once per epoch. If > 1, samples (or rather shard slices) are shuffled within this number of epochs (i.e. randomly selected without replacement). If -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices are drawn with replacement).

  • parallel_shard_iters (int | None) – Number of shards opened in parallel per worker; samples are shuffled across these open shards.

  • max_samples_per_sequence (int | None) – Maximum number of samples per sequence (=how many samples will be sequentially iterated).

  • split_config (str) – Config file to use for shard split definitions.

  • part_filter (Callable[[str], bool] | None) – (internal) Function for filtering tar files by dict keys.

  • handler (Callable[[Exception, str | None, list[SourceInfo] | None], None]) – Exception handler. Args: (exception, key, source infos).
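
The shuffle_over_epochs semantics can be illustrated with a small simulation over shard-slice indices. This is a sketch under stated assumptions: slices are plain integers, `draw_slices` is a hypothetical helper, and the handling of None is not modeled. It is not energon's sampler.

```python
import random
from collections import Counter
from itertools import islice
from typing import Iterator, List


def draw_slices(num_slices: int, shuffle_over_epochs: int, seed: int = 0) -> Iterator[int]:
    """Yield shard-slice indices forever according to shuffle_over_epochs."""
    if shuffle_over_epochs != -1 and shuffle_over_epochs < 1:
        raise ValueError("shuffle_over_epochs must be -1 or >= 1")
    rng = random.Random(seed)
    if shuffle_over_epochs == -1:
        # Infinite epochs: every draw is independent (with replacement).
        while True:
            yield rng.randrange(num_slices)
    while True:
        # One window holds every slice exactly shuffle_over_epochs times and is
        # consumed in random order (without replacement).
        window: List[int] = list(range(num_slices)) * shuffle_over_epochs
        rng.shuffle(window)
        yield from window


# Over 3 epochs of 4 slices (12 draws), every slice appears exactly 3 times.
counts = Counter(islice(draw_slices(4, shuffle_over_epochs=3), 12))
print(sorted(counts.values()))  # → [3, 3, 3, 3]
```

With shuffle_over_epochs=1 each window is a plain permutation, so every slice is seen once per epoch.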

as_file_store()[source]

Returns the dataset as a random access dataset.

Return type:

FileStore

build(worker_rotation_offset=0)[source]

Builds the dataset.

Parameters:

worker_rotation_offset (int)

Return type:

SavableDataset[T_sample]

config()[source]
Return type:

Dict[str, Any]

abstractmethod load_sample(raw_data)[source]

Loads the sample from the dataset.

Parameters:

raw_data (FilteredSample)

Return type:

T_sample

path: EPath
rank_shards: List[List[Sequence[ShardInfo]]]
sample_filter(key)[source]
Parameters:

key (str)

Return type:

bool

shards: List[ShardInfo]
training: bool
worker_config: WorkerConfig
class megatron.energon.Batch(*, __key__, __restore_key__, __subflavors__=None, __sources__=None)[source]

Bases: PinMemoryMixin, ExtendableDataclassMixin

Base class for a batch dataclass. Provides a default implementation for pinning memory. Additionally, it provides a future-safe way to create an instance from another batch via Batch.derive_from().

Parameters:
  • __key__ (list[str])

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (list[Dict[str, Any] | None] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

classmethod derive_from(base_batch, **kwargs)[source]

Uses the base fields of Batch from base_batch (i.e. __key__, __restore_key__, __subflavors__, __sources__) and creates a new batch with the kwargs as fields. This is useful for creating new batches, while keeping the metadata of the base batch.

Use like:

```python
def encode_batch(batch: RawBatch) -> Batch:
    return Batch.derive_from(batch, field1=batch.field1 + 1)
```

Parameters:
  • base_batch (Batch) – The base batch to copy the base fields / metadata from.

  • kwargs – The fields of the new batch.

Returns:

The new batch.

Return type:

T_batch
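
The idea behind derive_from can be sketched with plain dataclasses: copy the metadata fields from the base batch and fill the remaining fields from keyword arguments. `MiniBatch`, `RawBatch` and `EncodedBatch` are hypothetical types. Energon's Batch has additional behavior (e.g. memory pinning and optional metadata defaults) that is not reproduced here.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class MiniBatch:
    # Metadata fields mirroring Batch.__key__ etc. They have no defaults here so
    # that plain dataclass inheritance works in this sketch.
    __key__: List[str]
    __restore_key__: tuple
    __subflavors__: Optional[List[Optional[Dict[str, Any]]]]
    __sources__: Optional[tuple]

    @classmethod
    def derive_from(cls, base_batch: "MiniBatch", **kwargs: Any) -> "MiniBatch":
        # Copy only the base metadata fields; all other fields come from kwargs.
        meta = {f.name: getattr(base_batch, f.name) for f in dataclasses.fields(MiniBatch)}
        return cls(**meta, **kwargs)


@dataclass
class RawBatch(MiniBatch):
    field1: List[int]


@dataclass
class EncodedBatch(MiniBatch):
    field1: List[int]


raw = RawBatch(__key__=["a", "b"], __restore_key__=("ds", 0),
               __subflavors__=None, __sources__=None, field1=[1, 2])
encoded = EncodedBatch.derive_from(raw, field1=[x + 1 for x in raw.field1])
print(encoded.__key__, encoded.field1)  # → ['a', 'b'] [2, 3]
```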

classmethod from_samples(samples, **kwargs)[source]

Creates a batch from samples to be batched. Tensors are padded and stacked; other types are put into lists. This is the default implementation for Batch.from_samples.

Parameters:
  • samples (Sequence[Sample]) – The samples to batch.

  • kwargs – Additional (overriding) fields of the batch.

Returns:

The constructed batch.

Return type:

T_batch

class megatron.energon.BatchDataset(dataset, batch_size, batcher, *, batcher_stateless=False, batcher_config=None, drop_last=False, error_handler=<function log_exception>, failure_tolerance=100, worker_config)[source]

Bases: BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]

This dataset wrapper transforms a dataset of samples into a dataset of batches.

Parameters:
  • dataset (SavableDataset[T_batch_sample])

  • batch_size (int)

  • batcher (Callable[[List[T_batch_sample]], T_batch])

  • batcher_stateless (bool)

  • batcher_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • drop_last (bool)

  • error_handler (Callable[[Exception, list[T_batch_sample], list[SourceInfo]], None])

  • failure_tolerance (int | None)

  • worker_config (WorkerConfig)

__init__(dataset, batch_size, batcher, *, batcher_stateless=False, batcher_config=None, drop_last=False, error_handler=<function log_exception>, failure_tolerance=100, worker_config)[source]

Construct a BatchDataset.

Parameters:
  • dataset (SavableDataset[T_batch_sample]) – The input dataset to wrap

  • batch_size (int) – The desired batch size. The last batch may be smaller.

  • batcher (Callable[[List[T_batch_sample]], T_batch]) – Function which combines separate samples into a single object. May raise megatron.energon.SkipSample to skip a sample.

  • batcher_stateless (bool) – If True, the batcher is stateless, thus samples can be stored/restored.

  • batcher_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None) – Configuration for the batcher function. If callable, it should return the configuration. Defaults to None.

  • drop_last (bool) – If True, the last batch is dropped if it is smaller than the batch size.

  • error_handler (Callable[[Exception, List[T_batch_sample], List[SourceInfo]], None]) – Function which handles exceptions raised by the batcher. The default implementation logs the exception.

  • failure_tolerance (int | None) – The number of consecutive failures after which the dataset is considered broken.

  • worker_config (WorkerConfig) – Configuration for the workers.
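
The batching loop described above can be sketched as a generator, including drop_last and skipping a batch when the batcher raises. `SkipBatch` is a hypothetical stand-in for megatron.energon.SkipSample. error_handler and failure_tolerance are not modeled, and this is not energon's implementation.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
B = TypeVar("B")


class SkipBatch(Exception):
    """Hypothetical stand-in for megatron.energon.SkipSample."""


def batch_iter(
    samples: Iterable[T],
    batch_size: int,
    batcher: Callable[[List[T]], B],
    drop_last: bool = False,
) -> Iterator[B]:
    """Group samples into lists of batch_size and combine each with batcher."""

    def groups() -> Iterator[List[T]]:
        buffer: List[T] = []
        for sample in samples:
            buffer.append(sample)
            if len(buffer) == batch_size:
                yield buffer
                buffer = []
        # The last group may be smaller than batch_size unless drop_last is set.
        if buffer and not drop_last:
            yield buffer

    for group in groups():
        try:
            batch = batcher(group)
        except SkipBatch:
            continue  # The batcher asked to skip this batch.
        yield batch


print(list(batch_iter(range(7), 3, sum)))  # → [3, 12, 6]
print(list(batch_iter(range(7), 3, sum, drop_last=True)))  # → [3, 12]
```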

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

batch_size: int
batcher: Callable[[List[T_batch_sample]], T_batch]
can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

drop_last: bool
error_handler: Callable[[Exception, list[T_batch_sample], list[SourceInfo]], None]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(restore_key)[source]

Restores a batch from its restore key. The key type is generic because it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

The default implementation raises an exception: a dataset that does not implement this is assumed to be non-deterministic, and determinism is not guaranteed.

Parameters:

restore_key (Tuple[str | int | tuple, ...])

Return type:

T_batch

class megatron.energon.BlendDataset(*dataset_weights, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample]

This dataset wrapper blends multiple iterable datasets together given a weighting. The datasets may be infinite. This dataset is always infinite.

__init__(*dataset_weights, worker_config)[source]

Construct a BlendDataset.

Parameters:
  • dataset_weights (Tuple[SavableDataset[T_sample], float]) – Each argument should be a tuple of (dataset, weight) with a weight between 0 and 1. The output samples are sampled from the input datasets with the given probabilities.

  • worker_config (WorkerConfig) – Configuration for the workers.
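
Weighted blending of infinite iterators can be sketched with random.choices. `blend` is a hypothetical helper that assumes every input is infinite (e.g. wrapped in itertools.cycle). Energon's BlendDataset additionally handles worker sharding, exhausted inputs and savable state, none of which is shown here.

```python
import itertools
import random
from typing import Iterator, Sequence, Tuple, TypeVar

T = TypeVar("T")


def blend(dataset_weights: Sequence[Tuple[Iterator[T], float]], seed: int = 0) -> Iterator[T]:
    """Yield samples forever, choosing the source iterator by weight for each sample."""
    rng = random.Random(seed)
    iterators = [it for it, _ in dataset_weights]
    weights = [w for _, w in dataset_weights]
    while True:
        (chosen,) = rng.choices(iterators, weights=weights, k=1)
        yield next(chosen)


samples = list(itertools.islice(blend([(itertools.cycle(["a"]), 0.8),
                                       (itertools.cycle(["b"]), 0.2)]), 10_000))
print(samples.count("a") / len(samples))  # roughly 0.8
```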

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

exhausted: List[bool]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

weights: Tuple[float, ...]
class megatron.energon.CachePool[source]

Bases: ABC

A cache pool loads the needed data in the background so that it can be accessed later. The most important example is FileStoreCachePool, which caches data on a local SSD.

To request data, use get_lazy to get a Lazy object. Then, call Lazy.get() to get the data later on.

abstractmethod close()[source]

Close the cache pool.

Return type:

None

abstractmethod get(ds, fname, sample=None)[source]

Gets the data for a given file and adds the source info to the sample.

Parameters:
Return type:

Any

abstractmethod get_lazy(ds, fname)[source]

Get a lazy reference to the data for a given file.

Parameters:
Return type:

Lazy

class megatron.energon.CaptioningSample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, image, caption)[source]

Bases: Sample

Sample type for image captioning.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • image (Tensor)

  • caption (str)

caption: str

The caption string

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.CaptioningWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[CaptioningSample]

Parameters:

path (EPath)

class megatron.energon.ConcatDataset(*datasets, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset wrapper concatenates multiple iterable datasets together. The datasets must be finite; otherwise, not all datasets can be sampled. This is only useful for validation/test datasets.

__init__(*datasets, worker_config)[source]

Construct a concatenated dataset.

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.Cooker(cook, has_subflavors=None)[source]

Bases: Generic[T_sample]

A cooker transforms a crude sample (a simple dict) into a specific sample type inheriting from Sample. The cook method performs the transformation; the other fields select which samples this cooker can transform. If no filters are provided, the cooker transforms any CrudeSample.
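
A cooker's selection logic can be sketched as a subflavor match. This sketch rests on two assumptions made for illustration: the crude sample carries its subflavors under a `__subflavors__` key, and has_subflavors matches when every given key/value pair is present. `MiniCooker` and `cook_text` are hypothetical and do not reproduce energon's Cooker.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class MiniCooker:
    """Hypothetical cooker: a cook function plus an optional subflavor filter."""

    cook: Callable[..., Any]
    has_subflavors: Optional[dict] = None

    def is_match(self, crude_sample: dict) -> bool:
        # Without filters, the cooker accepts every crude sample.
        if self.has_subflavors is None:
            return True
        # Assumption: subflavors are stored under the "__subflavors__" key.
        subflavors = crude_sample.get("__subflavors__") or {}
        return all(k in subflavors and subflavors[k] == v
                   for k, v in self.has_subflavors.items())


def cook_text(sample: dict) -> str:
    return sample["txt"].decode("utf-8")


cookers = [MiniCooker(cook=cook_text, has_subflavors={"source": "web"})]
crude = {"__key__": "0001", "__subflavors__": {"source": "web"}, "txt": b"hello"}
cooker = next(c for c in cookers if c.is_match(crude))
print(cooker.cook(crude))  # → hello
```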

Parameters:
  • cook (Callable[[...], T_sample])

  • has_subflavors (dict | None)

cook: Callable[[...], T_sample]
has_subflavors: dict | None = None
is_match(crude_sample)[source]
Parameters:

crude_sample (CrudeSample)

Return type:

bool

property need_cache: bool
property need_primary: bool
class megatron.energon.CrudeSample[source]

Bases: dict

Generic sample type to be processed later.

class megatron.energon.CrudeWebdataset(path, *, subflavors=None, part_filter=<function CrudeWebdataset.<lambda>>, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[CrudeSample]

The CrudeWebdataset is used to load crude / raw samples and decode them in the user code using so-called cookers.

See the documentation under “Crude Data” for more information.

Parameters:
  • path (EPath)

  • subflavors (Dict[str, Any] | None)

  • part_filter (str | List[str] | Callable[[str], bool])

__init__(path, *, subflavors=None, part_filter=<function CrudeWebdataset.<lambda>>, **kwargs)[source]

Constructs a crude webdataset.

Parameters:
  • path (EPath) – Root path to the joined datasets.

  • subflavors (Dict[str, Any] | None) – Subflavors dictionary to set for all loaded samples.

  • part_filter (str | List[str] | Callable[[str], bool]) – Function for filtering tar files to load by dict keys.

  • **kwargs – Additional arguments to the BaseWebdataset constructor.

class megatron.energon.DatasetLoader(*, path, split_part=None, subflavors=None, shuffle_over_epochs_multiplier=1, dataset_config='dataset.yaml', split_config='split.yaml')[source]

Bases: DatasetLoaderInterface

Loads a dataset from a path.

Parameters:
  • path (str | EPath)

  • split_part (str | None)

  • subflavors (Dict[str, Any] | None)

  • shuffle_over_epochs_multiplier (int | None)

  • dataset_config (str)

  • split_config (str)

dataset_config: str
get_dataset(*, training, split_part=None, worker_config, subflavors=None, shuffle_over_epochs=1, split_config=None, dataset_config=None, **kwargs)[source]
Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (str | None) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration.

  • shuffle_buffer_size – Size of the sample shuffle buffer (before task encoding).

  • subflavors (Dict[str, Any] | None) – Subflavors to use, might be overridden by inner datasets.

  • shuffle_over_epochs (int | None) – Shuffle the dataset over this many epochs.

  • **kwargs – Additional arguments to the dataset constructor.

  • split_config (str | None)

  • dataset_config (str | None)

Returns:

The loaded dataset

Return type:

BaseCoreDatasetFactory

get_datasets(*, training, split_part, worker_config, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner dataset's shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.
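
The multiplier arithmetic from the example above (2 × 3 = 6 epochs) can be written as a small helper. `combine_multipliers` is hypothetical. It treats -1 on either side as "infinite epochs", which is an assumption based on the -1 convention described above rather than a documented rule. None values are not handled.

```python
def combine_multipliers(inner: int, outer: int) -> int:
    """Combine shuffle_over_epochs multipliers of nested dataset configs.

    Assumption: -1 (infinite epochs, i.e. drawing with replacement) on either
    side dominates; otherwise the factors multiply.
    """
    if inner == -1 or outer == -1:
        return -1
    if inner < 1 or outer < 1:
        raise ValueError("multipliers must be >= 1 or -1")
    return inner * outer


print(combine_multipliers(2, 3))  # → 6
print(combine_multipliers(2, -1))  # → -1
```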

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

LoadedDatasetList

path: str | EPath
post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

shuffle_over_epochs_multiplier: int | None
split_config: str
split_part: str | None
subflavors: Dict[str, Any] | None
class megatron.energon.DatasetLoaderInterface[source]

Bases: ABC

General interface for a dataset loader.

abstractmethod get_datasets(*, training, split_part, worker_config, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner dataset's shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

LoadedDatasetList

abstractmethod post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

prepare(split_part=None)[source]

Prepares the loader by creating caches and other necessary structures on disk.

Parameters:

split_part (str | None) – Name of the split to load.

Returns:

List of cache paths. This is used for cleanup.

Return type:

Sequence[EPath]

class megatron.energon.DecodeFileStore(inner_reader, *, decoder)[source]

Bases: FileStore[Any]

Used to wrap a FileStore and decode the data on access.

__init__(inner_reader, *, decoder)[source]
Parameters:
  • inner_reader (FileStore[bytes]) – The FileStore to wrap.

  • decoder (FileStoreDecoder) – The decoder to apply to every item read from the FileStore.

get_path()[source]

Returns the path to the dataset.

Return type:

str

class megatron.energon.DefaultDecoderWebdatasetFactory(path, *, decoder=<megatron.energon.flavors.webdataset.sample_decoder.SampleDecoder object>, **kwargs)[source]

Bases: DefaultGenericWebdatasetFactory[T_sample], Generic[T_sample]

Extends the default webdataset loading with decoding of contained files, such as images, videos or nested containers.

__init__(path, *, decoder=<megatron.energon.flavors.webdataset.sample_decoder.SampleDecoder object>, **kwargs)[source]

Factory for the webdataset sample loader including the decoder.

Parameters:
  • path (EPath) – Path to the dataset (passed to parent)

  • decoder (SampleDecoder | None) – If provided, use this decoder, otherwise just load raw bytes.

  • **kwargs – Args passed to parent constructor

config()[source]
Return type:

Dict[str, Any]

load_sample(sample)[source]

Loads the sample from the dataset.

Parameters:

sample (FilteredSample)

Return type:

T_sample

class megatron.energon.DefaultGenericWebdatasetFactory(path, *, subflavors=None, field_map=None, sample_loader=None, part_filter=None, **kwargs)[source]

Bases: BaseWebdatasetFactory[T_sample], Generic[T_sample]

Default implementation of webdataset for generic samples and the generic config interface for use with dataset.yaml.

Parameters:
  • path (EPath)

  • subflavors (Dict[str, Any] | None)

  • field_map (Dict[str, str] | None)

  • sample_loader (str | Callable[[dict], dict] | None)

  • part_filter (str | List[str] | Callable[[str], bool] | None)

__init__(path, *, subflavors=None, field_map=None, sample_loader=None, part_filter=None, **kwargs)[source]

Factory for the webdataset sample loader and basic configuration options.

Parameters:
  • subflavors (Dict[str, Any] | None) – Subflavors dictionary to set for all loaded samples.

  • field_map (Dict[str, str] | None) – Mapping from the webdataset fields to the sample fields.

  • sample_loader (str | Callable[[dict], dict] | None) – Function to load the sample from the webdataset fields. May be a string in order to load a function from a module, or a callable directly.

  • part_filter (str | List[str] | Callable[[str], bool] | None) – Filter for the parts to load. May be a string in order to load a function from a module, or a callable directly.

  • **kwargs – Args passed to parent constructor.

  • path (EPath)

config()[source]
Return type:

Dict[str, Any]

load_sample(sample)[source]

Loads the sample from the dataset.

Parameters:

sample (FilteredSample)

Return type:

T_sample

class megatron.energon.DefaultTaskEncoder(*, encoded_sample_type=None, raw_batch_type=None, batch_type=None)[source]

Bases: TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch], ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]

The default task encoder supports automatically mapping to target types. You may override any methods to customize the behavior. By default, encode_sample is the identity function, batch calls _batch with the type of the first sample, and encode_batch is also the identity function. If you set any of encoded_sample_type, raw_batch_type or batch_type, the corresponding method returns that type, automatically mapping the fields (by name) to your new type.

Parameters:
  • encoded_sample_type (Type[T_encoded_sample] | None)

  • raw_batch_type (Type[T_raw_batch] | None)

  • batch_type (Type[T_batch] | None)

__init__(*, encoded_sample_type=None, raw_batch_type=None, batch_type=None)[source]

Initialize the default task encoder. Types may be:

  • A @dataclass class: Return that typed dataclass. Field names must match the input fields.

  • A NamedTuple class: Return that typed namedtuple. Field names must match the input fields.

  • dict: Simply return the input as dict with field names as keys.

Parameters:
  • encoded_sample_type (Type[T_encoded_sample] | None) – Type of encoded samples (before batching)

  • raw_batch_type (Type[T_raw_batch] | None) – Type of the batched samples (after batching)

  • batch_type (Type[T_batch] | None) – Type of the encoded batched samples

  • cache – Cache pool to use for caching. If not provided, a no-op cache pool will be used.
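
The by-name field mapping can be sketched with dataclass and NamedTuple introspection. `convert_to`, `EncodedSample` and `EncodedTuple` are hypothetical. Unlike this sketch, which silently ignores extra input fields, energon's DefaultTaskEncoder may handle extra or missing fields differently.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Type


def convert_to(target: Type, fields: Dict[str, Any]) -> Any:
    """Build an instance of target from a dict, matching fields by name."""
    if target is dict:
        return dict(fields)
    if dataclasses.is_dataclass(target):
        names = [f.name for f in dataclasses.fields(target) if f.init]
    elif isinstance(target, type) and issubclass(target, tuple) and hasattr(target, "_fields"):
        names = list(target._fields)  # NamedTuple class
    else:
        raise TypeError(f"unsupported target type: {target!r}")
    # A missing input field raises KeyError; extra input fields are ignored here.
    return target(**{name: fields[name] for name in names})


@dataclass
class EncodedSample:
    image: List[float]
    caption: str


class EncodedTuple(NamedTuple):
    image: List[float]
    caption: str


src = {"image": [0.1, 0.2], "caption": "a cat"}
print(convert_to(EncodedSample, src))  # → EncodedSample(image=[0.1, 0.2], caption='a cat')
print(convert_to(EncodedTuple, src).caption)  # → a cat
```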

batch(samples)[source]

Batch a list of samples. The default implementation uses default batching to convert to _batch_type.

Parameters:

samples (List[T_encoded_sample])

Return type:

T_raw_batch

encode_batch(batch)[source]

Encode a batch of samples. The default implementation converts to the _encoded_batch_type.

Parameters:

batch (T_raw_batch)

Return type:

T_batch | Generator[T_batch, None, None]

encode_sample(sample)[source]

Encode a single sample. The default implementation converts to the _encoded_sample_type.

Parameters:

sample (T_sample)

Return type:

T_encoded_sample | Generator[T_encoded_sample, None, None]

class megatron.energon.DirectLazy(*, ds, fname, pool, _data=None)[source]

Bases: Lazy[T]

This is not truly lazy: it only defers the dataset access until the first get().

get(sample=None)[source]

Gets the lazy data now; no source info is added to the sample.

Parameters:

sample (Any)

Return type:

T

class megatron.energon.EpochizeDataset(dataset, length, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Uses the base dataset to create epochs of length samples each. Keeps the underlying dataset iterator alive across epochs (i.e. if it is an infinite dataset, its state is kept). Repeats the underlying dataset if the iterator is exhausted.

__init__(dataset, length, worker_config)[source]

Create the epochized dataset.

Parameters:
  • dataset (SavableDataset[T_sample]) – The source dataset (possibly infinite)

  • length (int) – Number of samples to iterate before iteration stops (i.e. one epoch). When iteration continues, the original dataset iterator is resumed and only restarts if exhausted.

  • worker_config (WorkerConfig) – Configuration for the workers.
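
Fixed-length epochs over a persistent iterator can be sketched as follows. `Epochizer` is hypothetical. It keeps the inner iterator alive between epochs and recreates it only when exhausted, as described above, but it does not implement energon's savable state or worker handling.

```python
from typing import Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


class Epochizer:
    def __init__(self, make_dataset: Callable[[], Iterable[T]], length: int) -> None:
        self._make_dataset = make_dataset
        self.length = length
        self._it: Optional[Iterator[T]] = None  # Survives across epochs.

    def __iter__(self) -> Iterator[T]:
        for _ in range(self.length):
            if self._it is None:
                self._it = iter(self._make_dataset())
            try:
                sample = next(self._it)
            except StopIteration:
                # Inner dataset exhausted: restart it (assumed non-empty).
                self._it = iter(self._make_dataset())
                sample = next(self._it)
            yield sample


ep = Epochizer(lambda: range(5), length=3)
print(list(ep), list(ep))  # → [0, 1, 2] [3, 4, 0]
```

The second epoch resumes where the first stopped and wraps around only when the inner dataset runs out.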

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

length: int
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.FileCacheLazy(*, ds, fname, pool, entry, _data=None)[source]

Bases: Lazy[T]

Represents a reference to a background prefetch.

entry: _PendingTask
get(sample=None)[source]

Returns the data and adds the source info to the sample. If the background job hasn’t started, we cancel it, do a direct read, and remove ourselves from the pool’s references. Otherwise, we wait for the job to finish, read from cache, and remove ourselves.

Parameters:

sample (Any)

Return type:

T

class megatron.energon.FileStore[source]

Bases: Generic[T]

Base type for a dataset that can be accessed randomly by sample key.

get(key, sample=None)[source]

Returns the data for the given key and adds the source info to the sample.

Parameters:
  • key (str)

  • sample (Any)

Return type:

Any

abstractmethod get_path()[source]

Returns the path to the dataset.

Return type:

str

class megatron.energon.FileStoreCachePool(*, parent_cache_dir=None, num_workers=8, max_cache_size_gbytes=1024, max_cache_count=10000000, method='raw')[source]

Bases: CachePool, ForkMixin

Manages a thread pool to pre-fetch data onto an SSD cache. Each (ds, fname) has one Future (one read). Multiple requests share that same future. We track usage with a refcount.

To avoid multi-process collisions, we generate a random subfolder for each instance.

Parameters:
  • parent_cache_dir (Path | None)

  • num_workers (int)

  • max_cache_size_gbytes (float)

  • max_cache_count (int)

  • method (Literal['raw', 'pickle'])

__init__(*, parent_cache_dir=None, num_workers=8, max_cache_size_gbytes=1024, max_cache_count=10000000, method='raw')[source]

Initialize the cache pool.

Parameters:
  • parent_cache_dir (Path | None) – The parent directory for the cache.

  • num_workers (int) – The number of worker threads to use for copying the data to the cache for lazy loading.

  • max_cache_size_gbytes (float) – The maximum size of the cache in gigabytes. If the cache exceeds this size, the prefetching will wait until the cache is below this size.

  • max_cache_count (int) – The maximum number of files in the cache. If the cache exceeds this number, the prefetching will wait until the cache is below this number.

  • method (Literal['raw', 'pickle']) – The method to use for caching. “raw”: store the non-decoded raw data. “pickle”: first decode the data, then store the pickled data.
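
The sharing scheme described above can be sketched with concurrent.futures: one Future per (ds, fname), shared by all requesters and tracked with a reference count. `SharedPrefetcher` and `read_fn` are hypothetical. The actual pool also writes to an on-disk cache directory and enforces the size and count limits, which this sketch omits.

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, Hashable, Tuple


class SharedPrefetcher:
    def __init__(self, read_fn: Callable[[Hashable, str], Any], num_workers: int = 8) -> None:
        self._read_fn = read_fn
        self._executor = ThreadPoolExecutor(max_workers=num_workers)
        self._lock = threading.Lock()
        # (ds, fname) -> [future, refcount]
        self._entries: Dict[Tuple[Hashable, str], list] = {}

    def request(self, ds: Hashable, fname: str) -> Future:
        """Schedule a read, or reuse the pending one, and take a reference."""
        key = (ds, fname)
        with self._lock:
            entry = self._entries.get(key)
            if entry is None:
                entry = [self._executor.submit(self._read_fn, ds, fname), 0]
                self._entries[key] = entry
            entry[1] += 1
            return entry[0]

    def release(self, ds: Hashable, fname: str) -> None:
        """Drop a reference; forget the entry once nobody needs it anymore."""
        key = (ds, fname)
        with self._lock:
            entry = self._entries[key]
            entry[1] -= 1
            if entry[1] == 0:
                del self._entries[key]

    def close(self) -> None:
        self._executor.shutdown(wait=True)
        self._entries.clear()


calls = []


def read_fn(ds, fname):
    calls.append((ds, fname))
    return f"{ds}/{fname}"


pool = SharedPrefetcher(read_fn, num_workers=2)
f1 = pool.request("ds0", "a.jpg")
f2 = pool.request("ds0", "a.jpg")  # Shares the same Future; no second read.
print(f1 is f2, f1.result())  # → True ds0/a.jpg
pool.release("ds0", "a.jpg")
pool.release("ds0", "a.jpg")
pool.close()
print(len(calls))  # → 1
```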

cache_dir: Path
close()[source]

Shutdown the pool, wait for tasks, and clear our structures.

Return type:

None

current_cache_count: int
current_cache_size: int
get(ds, fname, sample=None)[source]

Synchronous read from the dataset (no cache usage).

Parameters:
Return type:

Any

get_lazy(ds, fname)[source]

Schedule a background pre-fetch. If multiple calls come in for the same (ds, fname), they’ll share the same Future and increment reference counts.

Return type:

FileCacheLazy

max_cache_count: int
max_cache_size: int
method: Literal['raw', 'pickle']
class megatron.energon.FileStoreDecoder[source]

Bases: ABC

Abstract base class for decoders.

abstractmethod decode(fname, data)[source]

Decode the specified file (i.e. path/key.ext). The extension is used to select the decoder.

Parameters:
  • fname (str) – The file name of the file to decode.

  • data (bytes) – The raw bytes of the file to decode.

Returns:

The decoded field’s data.

Return type:

Any
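
Extension-based dispatch, as described for decode(), can be sketched with a dict of handlers. `ExtensionDecoder` and its handler table are hypothetical. The sketch does not subclass energon's FileStoreDecoder, and returning unknown extensions as raw bytes is an assumption.

```python
import json
import os
from typing import Any, Callable, Dict


class ExtensionDecoder:
    """Hypothetical decoder that picks a handler from the file extension."""

    def __init__(self) -> None:
        self._handlers: Dict[str, Callable[[bytes], Any]] = {
            "txt": lambda raw: raw.decode("utf-8"),
            "json": json.loads,
            "cls": int,  # int() accepts ASCII digits given as bytes.
        }

    def decode(self, fname: str, data: bytes) -> Any:
        # For "path/key.ext", the extension after the last dot selects the handler.
        ext = os.path.splitext(fname)[1].lstrip(".").lower()
        handler = self._handlers.get(ext)
        # Unknown extensions: return the raw bytes unchanged.
        return data if handler is None else handler(data)


dec = ExtensionDecoder()
print(dec.decode("shard/0001.json", b'{"label": 3}'))  # → {'label': 3}
print(dec.decode("shard/0001.cls", b"7"))  # → 7
```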

class megatron.energon.FilterDataset(dataset, *, filter_fn, filter_fn_config=None, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset wrapper applies a custom filter function to each sample and yields only the samples for which the filter returns True.

Parameters:
  • dataset (SavableDataset[T_sample])

  • filter_fn (Callable[[T_sample], bool])

  • filter_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • worker_config (WorkerConfig)

__init__(dataset, *, filter_fn, filter_fn_config=None, worker_config)[source]

Construct a FilterDataset.

Parameters:
  • dataset (SavableDataset[T_sample]) – The input dataset to wrap

  • filter_fn (Callable[[T_sample], bool]) – The function to apply to each sample. If it returns True, the sample is accepted.

  • filter_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None) – Configuration for the filter function. If callable, it should return the configuration. Defaults to None.

  • worker_config (WorkerConfig) – Configuration for the workers.

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

filter_fn: Callable[[T_sample], bool]
filter_fn_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.GcDataset(dataset, *, worker_config, every_n_iter=10, freeze=True)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Applies a garbage collection step. This is needed because Python's garbage collection does not work well with very large objects, such as tensors. The problem occurs when a few hundred objects (some of them large tensors) are created and released every epoch, with many of them alive at the same time but released later. Such objects may end up in gc generation 2, where they can stay alive until enough new objects have been created to trigger an automatic collection of generation 2. To avoid this memory leak, gc.collect() should be called regularly. In addition, calling gc.freeze() before the loop removes the objects currently alive from garbage collection checks, which makes garbage collection faster.

__init__(dataset, *, worker_config, every_n_iter=10, freeze=True)[source]

Construct a GcDataset, which applies garbage collection every every_n_iter iterations.

Parameters:
  • dataset (SavableDataset[T_sample]) – The input dataset to wrap

  • every_n_iter (int) – How often to perform garbage collection

  • freeze (bool) – If true, run gc.freeze() before the loop, and gc.unfreeze() after the loop. This will speed up garbage collection, but will keep all initially alive objects alive until the end of the loop (i.e. if the dataset state was restored, that state will be saved as well).

  • worker_config (WorkerConfig)
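The pattern described above can be sketched as a plain Python generator that uses only the standard gc module. This is a minimal illustration, not energon's GcDataset implementation; the function name gc_every_n is hypothetical, and its defaults simply mirror every_n_iter=10 and freeze=True.

```python
import gc
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def gc_every_n(samples: Iterable[T], every_n_iter: int = 10, freeze: bool = True) -> Iterator[T]:
    """Yield samples unchanged, running gc.collect() every `every_n_iter` samples."""
    if freeze:
        # Move all currently alive objects to the permanent generation, so
        # later collections do not have to scan them.
        gc.freeze()
    try:
        for i, sample in enumerate(samples, start=1):
            yield sample
            if i % every_n_iter == 0:
                # Explicit full collection, so large objects stuck in
                # generation 2 are freed regularly.
                gc.collect()
    finally:
        if freeze:
            # Return the frozen objects to the oldest generation.
            gc.unfreeze()


print(sum(gc_every_n(range(100), every_n_iter=10)))  # 4950
```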

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

every_n_iter: int
freeze: bool
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.GroupBatchDataset(dataset, fixed_batch_size, sample_group_key, batcher, *, batcher_stateless=False, batcher_config=None, drop_last=False, error_handler=<function log_exception>, failure_tolerance=100, worker_config)[source]

Bases: BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]

This dataset wrapper transforms a dataset of samples into a dataset of batches, grouped by some criterion. The reported length is not exact: since there is no fixed batch size, the number of batches cannot be predicted, so the inner dataset's size is returned instead. An example use case: image-text samples that are grouped by image size into three size categories (e.g. 128x128, 256x256, 512x512) for efficient augmentation and batching.

Parameters:
  • dataset (SavableDataset[T_batch_sample])

  • fixed_batch_size (int | None)

  • sample_group_key (Callable[[T_batch_sample], Tuple[Hashable, int | None]])

  • batcher (Callable[[List[T_batch_sample]], T_batch])

  • batcher_stateless (bool)

  • batcher_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • drop_last (bool)

  • error_handler (Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None])

  • failure_tolerance (int | None)

  • worker_config (WorkerConfig)

__init__(dataset, fixed_batch_size, sample_group_key, batcher, *, batcher_stateless=False, batcher_config=None, drop_last=False, error_handler=<function log_exception>, failure_tolerance=100, worker_config)[source]

Construct a GroupBatchDataset.

Parameters:
  • dataset (SavableDataset[T_batch_sample]) – The input dataset to wrap

  • fixed_batch_size (int | None) – Fixed batch size to use for all buckets. If None, the batch size is determined by the sample_group_key function.

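  • sample_group_key (Callable[[T_batch_sample], Tuple[Hashable, int | None]]) – Function which determines the bucket of a sample. It returns a tuple of the bucket key and the batch size for that bucket (which may be None if fixed_batch_size is set).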

  • batcher (Callable[[List[T_batch_sample]], T_batch]) – Function which combines separate samples into a single object. May raise megatron.energon.SkipSample to skip a sample.

  • drop_last (bool) – If True, the last batch is dropped if it is smaller than the batch size.

  • error_handler (Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None]) – Handler for errors. Defaults to logging and ignoring the exception.

  • failure_tolerance (int | None) – The number of consecutive failures after which the dataset is considered broken.

  • worker_config (WorkerConfig) – Configuration for the workers.

  • batcher_stateless (bool)

  • batcher_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)
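The bucketing idea can be sketched in plain Python: each sample goes into the bucket named by the first element of sample_group_key's result, and a bucket is emitted through the batcher once it reaches its batch size. This is a simplified model under stated assumptions (no error handling, no state saving), not energon's implementation; group_batches is a hypothetical name. For the image-size use case, sample_group_key might return the image's (H, W) together with a per-size batch size.

```python
from typing import Callable, Dict, Hashable, Iterable, Iterator, List, Optional, Tuple, TypeVar

T = TypeVar("T")
B = TypeVar("B")


def group_batches(
    samples: Iterable[T],
    sample_group_key: Callable[[T], Tuple[Hashable, Optional[int]]],
    batcher: Callable[[List[T]], B],
    fixed_batch_size: Optional[int] = None,
    drop_last: bool = False,
) -> Iterator[B]:
    """Yield one batch per filled bucket (hypothetical sketch)."""
    buckets: Dict[Hashable, List[T]] = {}
    for sample in samples:
        key, batch_size = sample_group_key(sample)
        if fixed_batch_size is not None:
            batch_size = fixed_batch_size
        if batch_size is None:
            raise ValueError("either fixed_batch_size or a per-bucket size is required")
        bucket = buckets.setdefault(key, [])
        bucket.append(sample)
        if len(bucket) == batch_size:
            yield batcher(bucket)
            buckets[key] = []  # start a fresh bucket for this key
    if not drop_last:
        # Flush incomplete buckets at the end of the stream.
        for bucket in buckets.values():
            if bucket:
                yield batcher(bucket)


# Group integers by parity, two per batch.
batches = list(group_batches(range(7), lambda i: (i % 2, 2), batcher=list))
print(batches)  # [[0, 2], [1, 3], [4, 6], [5]]
```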

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

batcher: Callable[[List[T_batch_sample]], T_batch]
can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

drop_last: bool
error_handler: Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic, because it might be either an integer (for a core dataset) or something more complex (e.g. for blended datasets).

The default implementation raises an exception (the dataset is assumed to be non-deterministic if this is not implemented; implementing it does not guarantee determinism).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_batch

restore_state(state)[source]

Restores the state of the dataset. This will restore the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Parameters:

state (FlexState) – The state of the dataset as savable object. If None, restore initial state.

Return type:

None

sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, int | None]]
save_state()[source]

Saves the state of the dataset. This will save and return the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Return type:

FlexState

class megatron.energon.ImageClassificationSample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, image, label=None, label_name=None)[source]

Bases: Sample

Sample type for classifying an image.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • image (Tensor)

  • label (int | None)

  • label_name (str | None)

image: Tensor

The input image tensor in the shape (C, H, W)

label: int | None

The class label of the image

label_name: str | None

The name of the image's class label

class megatron.energon.ImageClassificationWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[ImageClassificationSample]

Parameters:

path (EPath)

class megatron.energon.ImageSample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, image)[source]

Bases: Sample

Sample type for an image, e.g. for image reconstruction.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • image (Tensor)

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.ImageWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[ImageSample]

Parameters:

path (EPath)

class megatron.energon.InterleavedSample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, sequence)[source]

Bases: Sample

Sample type for interleaved media such as text with images.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • sequence (List[Tensor | str])

sequence: List[Tensor | str]

The interleaved media (either torch.Tensor for an image, or str for text)

class megatron.energon.InterleavedWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[InterleavedSample]

Parameters:

path (EPath)

class megatron.energon.IterMapDataset(dataset, iter_map_fn, *, len_map_fn=<function IterMapDataset.<lambda>>, error_handler=<function log_exception>, stateless_iter_fn=False, iter_map_fn_config=None, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]

This dataset wrapper applies a custom function to transform the stream of samples and yield a new stream of samples. If used in a savable dataset context, it is critical that iter_map_fn is either stateless, or that its state is saved and restored externally.

Parameters:
  • dataset (SavableDataset[T_sample])

  • iter_map_fn (Callable[[Iterator[T_sample]], Iterator[T_sample_out]])

  • len_map_fn (Callable[[int], int])

  • error_handler (Callable[[Exception, T_sample | None, list[SourceInfo]], None])

  • stateless_iter_fn (bool)

  • iter_map_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • worker_config (WorkerConfig)

__init__(dataset, iter_map_fn, *, len_map_fn=<function IterMapDataset.<lambda>>, error_handler=<function log_exception>, stateless_iter_fn=False, iter_map_fn_config=None, worker_config)[source]

Construct an IterMapDataset. For saving and restoring samples, the iter_map_fn must only yield 0 or 1 sample per iterated sample.

Parameters:
  • dataset (SavableDataset[T_sample]) – The input dataset to wrap

  • iter_map_fn (Callable[[Iterator[T_sample]], Iterator[T_sample_out]]) – The function to apply to the stream of samples. Returns a new stream of samples. If savability should be preserved, this function should be stateless.

  • len_map_fn (Callable[[int], int]) – The function to apply to the length of the dataset. Returns the new (approximate) length of the resulting stream of samples based on the original length.

  • error_handler (Callable[[Exception, T_sample | None, list[SourceInfo]], None]) – Handler for errors. Defaults to logging and ignoring the exception.

  • stateless_iter_fn (bool) – If true, assume the iter_map_fn is deterministic and stateless: it does not aggregate samples, so the key for random access can propagate to the inner dataset. Yielding zero or multiple samples per fetched sample is fine. Defaults to False.

  • iter_map_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None) – Configuration for the iter_map_fn function. If callable, it should return the configuration. Defaults to None.

  • worker_config (WorkerConfig) – Configuration for the workers.
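To illustrate the distinctions drawn above, here are three hypothetical iter_map_fn candidates over a stream of strings. By the criteria in the parameter descriptions, drop_empty yields 0 or 1 sample per input (compatible with saving and restoring), split_words is stateless but may yield several samples per input (a candidate for stateless_iter_fn=True), and pair_up aggregates samples and is therefore not stateless. All names are illustrative only.

```python
from typing import Iterator, List, Tuple


def drop_empty(samples: Iterator[str]) -> Iterator[str]:
    # Stateless, 0 or 1 output per input sample.
    for sample in samples:
        if sample.strip():
            yield sample


def split_words(samples: Iterator[str]) -> Iterator[str]:
    # Stateless, but may yield zero or several outputs per input sample.
    for sample in samples:
        yield from sample.split()


def pair_up(samples: Iterator[str]) -> Iterator[Tuple[str, str]]:
    # Stateful: two consecutive inputs are aggregated into one output, so an
    # output cannot be traced back to a single inner sample.
    pending: List[str] = []
    for sample in samples:
        pending.append(sample)
        if len(pending) == 2:
            yield (pending[0], pending[1])
            pending = []


stream = ["a b", "", "c", "d e f"]
print(list(drop_empty(iter(stream))))   # ['a b', 'c', 'd e f']
print(list(split_words(iter(stream))))  # ['a', 'b', 'c', 'd', 'e', 'f']
print(list(pair_up(iter(stream))))      # [('a b', ''), ('c', 'd e f')]
```

A matching len_map_fn would estimate the new length, e.g. returning n // 2 for pair_up.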

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

error_handler: Callable[[Exception, T_sample | None, list[SourceInfo]], None]
iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]]
iter_map_fn_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
len_map_fn: Callable[[int], int]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(restore_key)[source]

Restores a sample from its restore key. The key type is generic, because it might be either an integer (for a core dataset) or something more complex (e.g. for blended datasets).

The default implementation raises an exception (the dataset is assumed to be non-deterministic if this is not implemented; implementing it does not guarantee determinism).

Parameters:

restore_key (Tuple[str | int | tuple, ...])

Return type:

T_sample

stateless_iter_fn: bool
class megatron.energon.JoinedWebdatasetFactory(inner_datasets, *, training, worker_config, shuffle_over_epochs=1, parallel_shard_iters=None, max_samples_per_sequence=None, join_index, joiner, handler=<function reraise_exception>)[source]

Bases: BaseCoreDatasetFactory[T_sample], Sharder, ErrorHandler[T_sample], Generic[T_sample], ABC

Base class for all webdataset loaders. Applies proper sharding across workers. Can join multiple datasets.

Parameters:
  • inner_datasets (List[BaseWebdatasetFactory])

  • training (bool)

  • worker_config (WorkerConfig)

  • shuffle_over_epochs (int | None)

  • parallel_shard_iters (int | None)

  • max_samples_per_sequence (int | None)

  • join_index (EPath)

  • joiner (Type[T_sample] | Callable[[...], T_sample])

  • handler (Callable[[Exception, str | None, list[SourceInfo] | None], None])

__init__(inner_datasets, *, training, worker_config, shuffle_over_epochs=1, parallel_shard_iters=None, max_samples_per_sequence=None, join_index, joiner, handler=<function reraise_exception>)[source]

Constructs the loader for a joined webdataset. The samples from the inner datasets are joined into a single sample using the joiner function.

Parameters:
  • inner_datasets (Sequence[BaseWebdatasetFactory] | Mapping[str, BaseWebdatasetFactory]) – The inner datasets. Must be loaded internally with _is_composed=True. Either a list (*args for joiner) or a dict (**kwargs for joiner) of datasets; their samples are passed to the joiner function as *args or **kwargs, respectively.

  • training (bool) – If true, apply shuffling and loop the dataset.

  • worker_config (WorkerConfig) – Configuration for the workers.

  • shuffle_over_epochs (int | None) – Only effective if training=True. How many epochs to shuffle over if training. If = 1, every sample is seen exactly once per epoch. If > 1, samples (or rather shard slices) are shuffled within this number of epochs (i.e. randomly selected without replacement). If -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices are drawn with replacement).

  • parallel_shard_iters (int | None) – Number of shards opened in parallel per worker; samples are shuffled between these shards.

  • max_samples_per_sequence (int | None) – Maximum number of samples per sequence (i.e. how many samples will be iterated sequentially).

  • join_index (EPath) – Path to the join index file. Only required for join_method=“left”.

  • joiner (Type[T_sample] | Callable[[...], T_sample]) – Type of the joined samples or a method for joining the samples.

  • handler (Callable[[Exception, str | None, list[SourceInfo] | None], None]) – Exception handler. Args: (exception, key, source info).
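The three shuffle_over_epochs modes can be modeled with a small generator over shard slices. This is a deliberately simplified sketch, assuming slices are opaque items and a single seeded RNG; energon's actual sharding, slicing, and seeding logic is more involved. The name shard_slice_order is hypothetical.

```python
import random
from itertools import islice
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shard_slice_order(slices: Sequence[T], shuffle_over_epochs: int, seed: int = 0) -> Iterator[T]:
    """Endlessly yield shard slices following a simplified shuffle_over_epochs model."""
    rng = random.Random(seed)
    if shuffle_over_epochs == -1:
        # "Infinite epochs": every draw is independent, i.e. with replacement.
        while True:
            yield rng.choice(slices)
    if shuffle_over_epochs < 1:
        raise ValueError("shuffle_over_epochs must be >= 1 or -1")
    while True:
        # Pool shuffle_over_epochs copies of every slice and draw them without
        # replacement: within each block of that many epochs, every slice
        # appears exactly shuffle_over_epochs times.
        pool: List[T] = list(slices) * shuffle_over_epochs
        rng.shuffle(pool)
        yield from pool


slices = ["s0", "s1", "s2"]
two_epochs = list(islice(shard_slice_order(slices, shuffle_over_epochs=2), 6))
print(sorted(two_epochs))  # ['s0', 's0', 's1', 's1', 's2', 's2']
```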

as_file_store()[source]

Returns the dataset as a random access dataset.

Return type:

FileStore

build(worker_rotation_offset=0)[source]

Builds the dataset.

Parameters:

worker_rotation_offset (int)

Return type:

SavableDataset[T_sample]

config()[source]
Return type:

Dict[str, Any]

inner_dataset_keys: List[str] | None
inner_datasets: List[BaseWebdatasetFactory]
join_index: EPath
load_sample(samples)[source]
Parameters:

samples (RawSampleData)

Return type:

T_sample

max_samples_per_sequence: int | None
parallel_shard_iters: int | None
part_datasets: SavableDataset[T_sample]
property paths: List[EPath]
shards: List[Sequence[ShardInfo]]
shuffle_over_epochs: int | None = 1
training: bool
worker_config: WorkerConfig
class megatron.energon.Lazy(*, ds, fname, pool)[source]

Bases: Generic[T]

Abstract base class for lazy references to data.

ds: FileStore
fname: str
abstractmethod get(sample=None)[source]

Gets the lazy data now and adds the source info to the sample.

Parameters:

sample (Any)

Return type:

T

pool: CachePool
class megatron.energon.LimitDataset(dataset, length, *, reset_after_epoch=False, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Limits the length of the dataset.

__init__(dataset, length, *, reset_after_epoch=False, worker_config)[source]

Limits the length of the dataset.

Parameters:
  • dataset (SavableDataset[T_sample]) – The dataset to limit

  • length (int) – The length to limit to

  • reset_after_epoch (bool) – If true, reset the underlying dataset after one epoch.

  • worker_config (WorkerConfig) – Configuration for the workers.

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

current_offset: int
length: int
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

worker_has_samples()[source]

Returns True if the worker’s split has samples. This is used to determine if this dataset yields anything.

Return type:

bool

class megatron.energon.LogSampleDataset(dataset, mode, worker_config, get_keys_fn=<function default_get_keys>)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset logs every yielded sample to the debug logs.

Parameters:
  • dataset (SavableDataset[T_sample])

  • mode (Literal['train', 'val'])

  • worker_config (WorkerConfig)

  • get_keys_fn (Callable[[T_sample], List[str] | None])

__init__(dataset, mode, worker_config, get_keys_fn=<function default_get_keys>)[source]

Construct the log sample dataset, which logs every yielded sample to the debug logs.

Parameters:
  • dataset (SavableDataset[T_sample]) – The input dataset to wrap

  • mode (Literal['train', 'val'])

  • worker_config (WorkerConfig)

  • get_keys_fn (Callable[[T_sample], List[str] | None])

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

get_keys_fn: Callable[[T_sample], List[str] | None]
mode: Literal['train', 'val']
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.MapDataset(dataset, map_fn, *, error_handler=<function log_exception>, stateless_map_fn=False, map_fn_config=None, failure_tolerance=100, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]

This dataset wrapper applies a custom function to transform each sample.

Parameters:
  • dataset (SavableDataset[T_sample])

  • map_fn (Callable[[T_sample], T_sample_out | Generator[T_sample_out, None, None]])

  • error_handler (Callable[[Exception, T_sample, list[SourceInfo]], None])

  • stateless_map_fn (bool)

  • map_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • failure_tolerance (int | None)

  • worker_config (WorkerConfig)

__init__(dataset, map_fn, *, error_handler=<function log_exception>, stateless_map_fn=False, map_fn_config=None, failure_tolerance=100, worker_config)[source]

Construct a MapDataset.

If this should be savable, the map_fn must only return a sample, or a generator yielding 0 or 1 sample per input sample. Otherwise, saving and restoring will not work correctly (see IterMapDataset).

Parameters:
  • dataset (SavableDataset[T_sample]) – The input dataset to wrap

  • map_fn (Callable[[T_sample], T_sample_out | Generator[T_sample_out, None, None]]) – The function to apply to each sample. May raise megatron.energon.SkipSample to skip a sample. Alternatively, may return a generator to yield multiple or no samples.

  • error_handler (Callable[[Exception, T_sample, list[SourceInfo]], None]) – Handler for errors. Defaults to logging and ignoring the exception.

  • stateless_map_fn (bool) – If true, the map_fn is deterministic and stateless (thus the key for random access can propagate to the inner dataset). Defaults to False.

  • map_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None) – Configuration for the map_fn function. If callable, it should return the configuration. Defaults to None.

  • failure_tolerance (int | None) – The number of consecutive failures after which the dataset is considered broken.

  • worker_config (WorkerConfig) – Worker configuration.
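The two map_fn styles described above (return one sample, or return a generator yielding 0 or 1 sample) can be sketched without energon. SkipSample here is a local stand-in for megatron.energon.SkipSample, and apply_map is a hypothetical driver that only mimics how such results could be consumed; it is not MapDataset's implementation.

```python
import inspect
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Union

Sample = Dict[str, Any]  # assumed sample type for this sketch


class SkipSample(Exception):
    """Local stand-in for megatron.energon.SkipSample."""


def normalize(sample: Sample) -> Sample:
    # Plain function: returns exactly one sample, or skips it.
    if not sample["text"]:
        raise SkipSample()
    return {**sample, "text": sample["text"].lower()}


def keep_short(sample: Sample) -> Generator[Sample, None, None]:
    # Generator yielding 0 or 1 sample per input: still compatible with saving.
    if len(sample["text"]) <= 5:
        yield sample


def apply_map(
    samples: Iterable[Sample],
    map_fn: Callable[[Sample], Union[Sample, Iterator[Sample]]],
) -> Iterator[Sample]:
    """Hypothetical driver: handles direct returns, generators and skips."""
    for sample in samples:
        try:
            result = map_fn(sample)
        except SkipSample:
            continue
        if inspect.isgenerator(result):
            yield from result
        else:
            yield result


samples = [{"text": "Hello"}, {"text": ""}, {"text": "Much longer"}]
print(list(apply_map(samples, normalize)))   # [{'text': 'hello'}, {'text': 'much longer'}]
print(list(apply_map(samples, keep_short)))  # [{'text': 'Hello'}, {'text': ''}]
```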

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

error_handler: Callable[[Exception, T_sample, list[SourceInfo]], None]
map_fn: Callable[[T_sample], T_sample_out | Generator[T_sample_out, None, None]]
map_fn_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(restore_key)[source]

Restores a sample from its restore key. The key type is generic, because it might be either an integer (for a core dataset) or something more complex (e.g. for blended datasets).

The default implementation raises an exception (the dataset is assumed to be non-deterministic if this is not implemented; implementing it does not guarantee determinism).

Parameters:

restore_key (Tuple[str | int | tuple, ...])

Return type:

T_sample_out

stateless_map_fn: bool
class megatron.energon.Metadataset(path, splits)[source]

Bases: DatasetLoaderInterface

Main entry point for a metadataset.

Parameters:
  • path (EPath | str)

  • splits (Dict[str, MetadatasetBlender])

__init__(path, splits)[source]

Create the metadataset.

Parameters:
  • path (EPath | str)

  • splits (Dict[str, MetadatasetBlender])

get_datasets(*, training, split_part, worker_config, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner datasets' shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

LoadedDatasetList

post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

class megatron.energon.MetadatasetV2(*, path: megatron.energon.epathlib.epath.EPath, splits: Dict[str, megatron.energon.metadataset.metadataset_v2.MetadatasetBlend | megatron.energon.metadataset.metadataset_v2.MetadatasetBlendEpochized | megatron.energon.metadataset.metadataset_v2.MetadatasetJoin | megatron.energon.metadataset.metadataset_v2.DatasetReference])[source]

Bases: DatasetLoaderInterface

Parameters:
  • path (EPath)

  • splits (Dict[str, MetadatasetBlend | MetadatasetBlendEpochized | MetadatasetJoin | DatasetReference])

get_datasets(*, training, split_part, worker_config, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner datasets' shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

LoadedDatasetList

path: EPath
post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

prepare(split_part=None)[source]

Prepares the loader by creating caches and other necessary structures on disk.

Parameters:

split_part (str | None) – Name of the split to load.

Returns:

List of the cache paths. This is used for cleanup.

Return type:

Sequence[EPath]

splits: Dict[str, MetadatasetBlend | MetadatasetBlendEpochized | MetadatasetJoin | DatasetReference]
class megatron.energon.MixBatchDataset(*dataset_weights, batch_size, batch_mix_fn=<function MixBatchDataset.<lambda>>, worker_config)[source]

Bases: BaseWrapperDataset[T_batch_in, T_batch], Generic[T_batch_in, T_batch]

This dataset wrapper blends multiple iterable datasets together, given a weight for each. The datasets may be infinite. This dataset is always infinite. Effectively combines megatron.energon.BlendDataset and megatron.energon.BatchDataset.

Parameters:
  • dataset_weights (Tuple[SavableDataset[T_batch_in], float])

  • batch_size (int)

  • batch_mix_fn (Callable[[List[T_batch_in]], T_batch | Generator[T_batch, None, None]])

  • worker_config (WorkerConfig)

__init__(*dataset_weights, batch_size, batch_mix_fn=<function MixBatchDataset.<lambda>>, worker_config)[source]

Construct a MixBatchDataset.

Parameters:
  • dataset_weights (Tuple[SavableDataset[T_batch_in], float]) – Each argument should be a tuple of (dataset, weight) with a weight between 0 and 1. The output samples are sampled from the input datasets with the given probabilities. The datasets should have a batch size of 1; otherwise, whole batches will be sampled.

  • batch_size (int) – The batch size to output.

  • batch_mix_fn (Callable[[List[T_batch_in]], T_batch | Generator[T_batch, None, None]]) – A function that takes a list of samples from the input datasets and returns a batch sample. The default implementation returns a list of batches. For homogeneous datasets, it is recommended to use megatron.energon.homogeneous_concat_mix(), which concatenates the batches. May raise megatron.energon.SkipSample to skip a sample. May also return a generator, which will be iterated over to produce batches.

  • worker_config (WorkerConfig) – Configuration for the workers.
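The weighted mixing can be sketched with the standard random module: for each slot of a batch, a source is picked with probability proportional to its weight, and the collected items are passed to a batch_mix_fn (here list, mirroring the default of returning a list). This is a hypothetical model (mix_batches is not an energon function) that assumes infinite input iterators.

```python
import random
from itertools import cycle, islice
from typing import Any, Callable, Iterator, List, Sequence, Tuple


def mix_batches(
    dataset_weights: Sequence[Tuple[Iterator[Any], float]],
    batch_size: int,
    batch_mix_fn: Callable[[List[Any]], Any] = list,
    seed: int = 0,
) -> Iterator[Any]:
    """Endlessly yield batches whose elements are drawn from weighted sources."""
    rng = random.Random(seed)
    iterators = [it for it, _ in dataset_weights]
    weights = [weight for _, weight in dataset_weights]
    while True:
        # Pick a source for each slot independently, proportional to its weight.
        # Assumes infinite sources; an exhausted source would raise here.
        picks = rng.choices(range(len(iterators)), weights=weights, k=batch_size)
        yield batch_mix_fn([next(iterators[i]) for i in picks])


# Infinite toy sources; a weight of 0.0 means that source is never picked.
only_a = mix_batches([(cycle("a"), 1.0), (cycle("b"), 0.0)], batch_size=3)
print(next(only_a))  # ['a', 'a', 'a']

mixed = list(islice(mix_batches([(cycle("a"), 0.7), (cycle("b"), 0.3)], batch_size=4), 5))
print(len(mixed), [len(b) for b in mixed])  # 5 [4, 4, 4, 4, 4]
```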

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.MockLazy(fname, get_fn)[source]

Bases: Lazy[T]

Mock object, which can be used as a Lazy. Allows the user to set the function to retrieve the data. May be used to create a Lazy that is initialized from a function.

Parameters:
  • fname (str)

  • get_fn (Callable[[str], T])

__init__(fname, get_fn)[source]

Initialize the MockLazy object.

Parameters:
  • fname (str) – The file name of the mock object (may be used by the user).

  • get_fn (Callable[[str], T]) – The function to retrieve/generate the data.
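The idea behind MockLazy (a file name plus a function that produces the data only when get() is called) can be illustrated without importing energon. SimpleLazy below is a hypothetical stand-in, not energon code; with energon itself, the documented constructor is MockLazy(fname, get_fn).

```python
from dataclasses import dataclass
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


@dataclass
class SimpleLazy(Generic[T]):
    """Hypothetical stand-in illustrating the MockLazy idea (not energon code)."""

    fname: str
    get_fn: Callable[[str], T]

    def get(self) -> T:
        # Nothing is loaded until get() is called; then get_fn receives fname.
        return self.get_fn(self.fname)


calls: List[str] = []


def load(fname: str) -> bytes:
    calls.append(fname)  # record each actual load
    return f"contents of {fname}".encode()


lazy = SimpleLazy("image_0001.jpg", load)
print(calls)       # []
print(lazy.get())  # b'contents of image_0001.jpg'
print(calls)       # ['image_0001.jpg']
```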

get(sample=None)[source]

Gets the lazy data now; does not add any source info to the sample.

Parameters:

sample (Any)

Return type:

T

get_fn: Callable[[str], T]
class megatron.energon.MultiChoiceVQASample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, image, context, choices=None, correct_choice_idx=0)[source]

Bases: Sample

Sample type for visual question answering.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • image (Tensor)

  • context (str)

  • choices (List[str] | None)

  • correct_choice_idx (int)

choices: List[str] | None

The candidate answers.

context: str

The context/question for the image

correct_choice_idx: int

The index of the correct answer.

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.MultiChoiceVQAWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[MultiChoiceVQASample]

Parameters:

path (EPath)

class megatron.energon.NoCachePool[source]

Bases: CachePool

A pass-through cache pool that does not cache anything.

close()[source]

Close the cache pool.

Return type:

None

get(ds, fname, sample=None)[source]

Gets the data for a given file and adds the source info to the sample.

Parameters:
Return type:

Any

get_lazy(ds, fname)[source]

Get a lazy reference to the data for a given file.

Return type:

DirectLazy

class megatron.energon.OCRSample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, image, text, block_boxes=None, block_classes=None, block_text=None, lines_boxes=None, lines_text=None, words_boxes=None, words_text=None, chars_boxes=None, chars_text=None)[source]

Bases: Sample

Sample type for optical character recognition.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • image (Tensor)

  • text (str)

  • block_boxes (Tensor | None)

  • block_classes (Tensor | List[str] | None)

  • block_text (List[str] | None)

  • lines_boxes (Tensor | None)

  • lines_text (List[str] | None)

  • words_boxes (Tensor | None)

  • words_text (List[str] | None)

  • chars_boxes (Tensor | None)

  • chars_text (List[str] | None)

block_boxes: Tensor | None

The bounding boxes of the blocks in the image float(N, 4|5<x, y, w, h>)

block_classes: Tensor | List[str] | None

The classes of the blocks in the image int(N, 1<block_class>)

block_text: List[str] | None

The text contained in each block (N,)

chars_boxes: Tensor | None

The bounding boxes of the chars in the image float(N, 4|5<x, y, w, h[, confidence]>)

chars_text: List[str] | None

The character contained in each char (N,)

image: Tensor

The input image tensor in the shape (C, H, W)

lines_boxes: Tensor | None

The bounding boxes of the lines in the image float(N, 4|5<x, y, w, h[, confidence]>)

lines_text: List[str] | None

The text contained in each line (N,)

text: str

The text contained in the image

words_boxes: Tensor | None

The bounding boxes of the words in the image float(N, 4|5<x, y, w, h[, confidence]>)

words_text: List[str] | None

The text contained in each word (N,)

class megatron.energon.OCRWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[OCRSample]

Parameters:

path (EPath)

class megatron.energon.PackingDataset(dataset, buffer_size, pre_packer, final_packer, *, final_packer_stateless=False, sample_encoder=None, sample_encoder_stateless=False, packer_config=None, error_handler=<function log_exception>, pre_packer_failure_tolerance=100, final_packer_failure_tolerance=100, sample_encoder_failure_tolerance=100, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_encoded_sample, T_batch_sample], Generic[T_sample, T_encoded_sample, T_batch_sample]

This dataset wrapper transforms samples of a dataset into chunks/packs of samples, which are then combined into a batch.

Parameters:
  • dataset (SavableDataset[T_sample])

  • buffer_size (int)

  • pre_packer (Callable[[List[T_sample]], List[List[T_sample]]])

  • final_packer (Callable[[List[T_encoded_sample]], T_batch_sample])

  • final_packer_stateless (bool)

  • sample_encoder (Callable[[T_sample], T_encoded_sample] | None)

  • sample_encoder_stateless (bool)

  • packer_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • error_handler (Callable[[Exception, List[T_sample], list[SourceInfo]], None])

  • pre_packer_failure_tolerance (int | None)

  • final_packer_failure_tolerance (int | None)

  • sample_encoder_failure_tolerance (int | None)

  • worker_config (WorkerConfig)

__init__(dataset, buffer_size, pre_packer, final_packer, *, final_packer_stateless=False, sample_encoder=None, sample_encoder_stateless=False, packer_config=None, error_handler=<function log_exception>, pre_packer_failure_tolerance=100, final_packer_failure_tolerance=100, sample_encoder_failure_tolerance=100, worker_config)[source]

Construct a PackingDataset, which is used for sequence packing. It buffers the incoming samples, groups them according to the logic of the pre_packer, and then uses the final_packer to combine each group into a single packed sample, also called a “pack” or a “packed sequence”.

Parameters:
  • dataset (SavableDataset[T_sample]) – The input dataset to wrap

  • buffer_size (int) – The desired size of the input buffer for pre-packing. The last buffer of a dataset may be smaller.

  • pre_packer (Callable[[List[T_sample]], List[List[T_sample]]]) – Function which selects samples from the buffer to be packed together. May raise megatron.energon.SkipSample to skip a buffer.

  • final_packer (Callable[[List[T_encoded_sample]], T_batch_sample]) – Function which combines the selected samples into a single sample.

  • final_packer_stateless (bool) – If True, the final_packer is stateless, thus samples can be stored/restored.

  • sample_encoder (Callable[[T_sample], T_encoded_sample] | None) – Function which encodes each sample.

  • sample_encoder_stateless (bool) – If True, the sample_encoder is stateless, thus samples can be stored/restored.

  • packer_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None) – Configuration for the (pre|final)_packer functions. If callable, it should return the configuration. Defaults to None.

  • error_handler (Callable[[Exception, List[T_sample], list[SourceInfo]], None]) – Function which handles exceptions raised by the batcher. The default implementation logs the exception.

  • pre_packer_failure_tolerance (int | None) – Maximum number of pre-packer failures before raising an error.

  • final_packer_failure_tolerance (int | None) – Maximum number of final-packer failures before raising an error.

  • sample_encoder_failure_tolerance (int | None) – Maximum number of sample-encoder failures before raising an error.

  • worker_config (WorkerConfig) – Configuration for the workers.
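The shapes of the pre_packer and final_packer callables can be illustrated with a small, self-contained sketch. It assumes that samples are plain lists of token ids and that the sample_encoder is the identity, and it uses a greedy first-fit heuristic, which is only one of many possible packing strategies. MAX_LEN, greedy_pre_packer and concat_final_packer are hypothetical names, not part of energon.

```python
from typing import List

# Hypothetical maximum length of one packed sequence (assumption for this sketch).
MAX_LEN = 8


def greedy_pre_packer(samples: List[List[int]]) -> List[List[List[int]]]:
    """Group the buffered samples into packs whose total length fits MAX_LEN (first-fit)."""
    packs: List[List[List[int]]] = []
    used: List[int] = []
    # Placing long samples first usually wastes less space.
    for sample in sorted(samples, key=len, reverse=True):
        for i, length in enumerate(used):
            if length + len(sample) <= MAX_LEN:
                packs[i].append(sample)
                used[i] += len(sample)
                break
        else:
            # No existing pack has room left, so open a new one.
            packs.append([sample])
            used.append(len(sample))
    return packs


def concat_final_packer(samples: List[List[int]]) -> List[int]:
    """Combine the samples of one pack into a single packed sequence."""
    packed: List[int] = []
    for sample in samples:
        packed.extend(sample)
    return packed


buffer = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10], [11]]
packs = greedy_pre_packer(buffer)
print([concat_final_packer(p) for p in packs])  # [[6, 7, 8, 9, 10, 1, 2, 3], [4, 5, 11]]
```

In this sketch a sample longer than MAX_LEN still gets a pack of its own, so a real pre_packer may need to truncate or drop such samples.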

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

buffer_size: int
can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Keys in the config dicts that start with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

error_handler: Callable[[Exception, List[T_sample], list[SourceInfo]], None]
final_packer: Callable[[List[T_encoded_sample]], T_batch_sample]
final_packer_stateless: bool
packer_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
pre_packer: Callable[[List[T_sample]], List[List[T_sample]]]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(restore_key)[source]

Restores a sample from a restore key. The key type is generic, because it might be either an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default, this raises an exception (if not implemented, the dataset is assumed to be non-deterministic; determinism is not guaranteed).

Parameters:

restore_key (Any)

Return type:

T_sample

sample_encoder: Callable[[T_sample], T_encoded_sample] | None
sample_encoder_stateless: bool
class megatron.energon.RepeatDataset(dataset, *, repeats=None, restart=True, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset repeats the inner dataset indefinitely or for a specific number of repeats.

__init__(dataset, *, repeats=None, restart=True, worker_config)[source]

Construct a RepeatDataset.

Parameters:
  • dataset (SavableDataset[T_sample]) – The input dataset to repeat.

  • repeats (int | float | None) – Number of repeats, or None to repeat indefinitely.

  • restart (bool) – If True and repeats is set to an integer, the underlying dataset is restarted after iterating through all repeats, but iteration still stops.

  • worker_config (WorkerConfig) – Configuration for the workers.
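As a rough sketch of the repeats semantics, a plain generator behaves as follows. This is an illustration under stated assumptions: integer repeats only (no fractional repeats), no restart handling and no state saving. The name repeat_samples is hypothetical.

```python
import itertools
from typing import Iterator, Optional, Sequence, TypeVar

T = TypeVar("T")


def repeat_samples(dataset: Sequence[T], repeats: Optional[int] = None) -> Iterator[T]:
    """Yield every sample of `dataset` `repeats` times, or forever if repeats is None."""
    # Note: with repeats=None and an empty dataset, this loops forever without yielding.
    passes = itertools.count() if repeats is None else range(repeats)
    for _ in passes:
        # A Sequence can be iterated again on every pass (a one-shot iterator could not).
        yield from dataset
```

With repeats=None the caller decides when to stop, e.g. via itertools.islice.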

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Keys in the config dicts that start with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

repeats: int | float | None
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.Sample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None)[source]

Bases: ABC, PinMemoryMixin, ExtendableDataclassMixin

An abstract base class for one element of a batch. Each task should derive a specific subclass as a @dataclass, like megatron.energon.CaptioningBatchSample, and add the input and output fields as needed for training.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

classmethod derive_from(base_sample, **kwargs)[source]

Uses the base fields of Sample from base_sample (i.e. __key__, __restore_key__, __subflavors__, __sources__) and creates a new sample with the kwargs as fields. This is useful for creating new samples, while keeping the metadata of the base sample.

Parameters:
  • base_sample (Sample) – The base sample to copy the base fields / metadata from.

  • kwargs – The fields of the new sample.

Returns:

The new sample.

Return type:

T_sample
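A minimal sketch of the idea behind derive_from, using plain dataclasses instead of the real Sample class. ToySample, ToyImageSample and ToyCaptionSample are hypothetical stand-ins, and __sources__ is omitted for brevity. The metadata fields are copied from the base sample, while all task-specific fields must be passed as keyword arguments.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Tuple, Type, TypeVar

T = TypeVar("T", bound="ToySample")


@dataclass
class ToySample:
    """Hypothetical stand-in for the metadata fields of megatron.energon.Sample."""

    __key__: str
    __restore_key__: Tuple[Any, ...]
    __subflavors__: Optional[Dict[str, Any]]

    @classmethod
    def derive_from(cls: Type[T], base_sample: "ToySample", **kwargs: Any) -> T:
        # Copy only the metadata fields declared on the base class...
        meta = {f.name: getattr(base_sample, f.name) for f in fields(ToySample)}
        # ...and take all task-specific fields from kwargs.
        return cls(**meta, **kwargs)


@dataclass
class ToyImageSample(ToySample):
    path: str


@dataclass
class ToyCaptionSample(ToySample):
    caption: str


raw = ToyImageSample(
    __key__="shard_000/00042",
    __restore_key__=("Webdataset", 42),
    __subflavors__={"source": "web"},
    path="00042.jpg",
)
cooked = ToyCaptionSample.derive_from(raw, caption="a cat on a sofa")
print(cooked)
```

Fields of the base sample that are not metadata (here `path`) are not carried over.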

classmethod from_joined(*args, **kwargs)[source]

Creates a sample from joined samples. The samples are either passed as positional arguments or as keyword arguments. The first sample is the primary sample, which is used to initialize the key and subflavors.

In the default implementation, the fields of the joined samples are merged so that later joined samples update the fields last (i.e. take precedence), except for the key and subflavors. The restore key is set externally afterwards.

Parameters:
  • args (Sample | None) – The samples to join (either this or kwargs is specified).

  • kwargs (Sample | None) – The samples to join (either this or args is specified). Not supported by the default implementation. Overriding implementations may use this.

Returns:

The joined constructed sample.

Return type:

T_sample
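The default merge rule can be sketched on plain field dictionaries. This illustrates the behavior described above and is not energon's implementation: join_fields is a hypothetical helper, ignoring None entries is an assumption, and the restore key is not handled because it is set externally.

```python
from typing import Any, Dict, Optional


def join_fields(*samples: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge per-sample field dicts; later samples win, except for key and subflavors."""
    primary = samples[0]
    if primary is None:
        raise ValueError("the primary (first) sample is required")
    joined: Dict[str, Any] = {}
    for sample in samples:
        if sample is None:  # assumption: missing secondary samples are ignored
            continue
        joined.update(sample)  # later samples overwrite earlier fields
    # Key and subflavors always come from the primary sample.
    joined["__key__"] = primary["__key__"]
    joined["__subflavors__"] = primary["__subflavors__"]
    return joined


primary = {"__key__": "img/1", "__subflavors__": {"lang": "en"}, "image": "pixels", "text": "draft"}
secondary = {"__key__": "txt/1", "__subflavors__": None, "text": "final"}
joined = join_fields(primary, secondary, None)
```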

class megatron.energon.SampleDecoder(*, image_decode='torchrgb', av_decode='AVDecoder', video_decode_audio=False, guess_content=False)[source]

Bases: FileStoreDecoder

The default decoder for webdataset samples.

Parameters:
  • image_decode (Literal['l8', 'rgb8', 'rgba8', 'l', 'rgb', 'rgba', 'torchl8', 'torchrgb8', 'torchrgba8', 'torchl', 'torchrgb', 'torch', 'torchrgba', 'pill', 'pil', 'pilrgb', 'pilrgba'])

  • av_decode (Literal['torch', 'AVDecoder', 'pyav'])

  • video_decode_audio (bool)

  • guess_content (bool)

__init__(*, image_decode='torchrgb', av_decode='AVDecoder', video_decode_audio=False, guess_content=False)[source]
Parameters:
  • image_decode (Literal['l8', 'rgb8', 'rgba8', 'l', 'rgb', 'rgba', 'torchl8', 'torchrgb8', 'torchrgba8', 'torchl', 'torchrgb', 'torch', 'torchrgba', 'pill', 'pil', 'pilrgb', 'pilrgba']) – Defines the output format of decoded images (e.g. 'torchrgb' for an RGB torch tensor, 'pil' for a PIL image).

  • av_decode (Literal['torch', 'AVDecoder', 'pyav']) – If “AVDecoder”, returns an AVDecoder instance for flexible decoding. If “torch”, returns decoded VideoData.

  • video_decode_audio (bool) – Whether to decode audio from video files.

  • guess_content (bool) – Whether to guess the contents of the file using the filetype package.

config()[source]
Return type:

dict

decode(fname, raw)[source]

Decode the specified file (i.e. path/key.ext). The extension is used to select the decoder.

Parameters:
  • fname (str) – The file name of the file to decode.

  • raw (bytes) – The raw bytes of the file to decode.

Returns:

The decoded field’s data.

Return type:

Any
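Extension-based dispatch, as described for decode(), can be sketched as a lookup table. The handler table and the decode function below are hypothetical and cover only three trivial formats; treating the text after the last dot of the file name as the extension is also an assumption.

```python
import json
from typing import Any, Callable, Dict

# Hypothetical handler table; the real SampleDecoder supports images, video, audio, etc.
_HANDLERS: Dict[str, Callable[[bytes], Any]] = {
    "txt": lambda raw: raw.decode("utf-8"),
    "json": lambda raw: json.loads(raw),
    "cls": lambda raw: int(raw.decode("utf-8")),
}


def decode(fname: str, raw: bytes) -> Any:
    """Pick a decoder based on the file extension of `fname` (e.g. path/key.ext)."""
    basename = fname.rsplit("/", 1)[-1]
    ext = basename.rsplit(".", 1)[-1].lower()
    handler = _HANDLERS.get(ext)
    if handler is None:
        return raw  # unknown extension: keep the raw bytes
    return handler(raw)
```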

class megatron.energon.SavableDataLoader(dataset, *, checkpoint_every_sec=60, checkpoint_every_min_n_samples=None, n_checkpoints=None, gc_collect_every_n_steps=10, gc_freeze_at_start=True, prefetch_factor=2, cache_pool=None, watchdog_timeout_seconds=60, watchdog_initial_timeout_seconds=None, fail_on_timeout=False)[source]

Bases: DataLoader[T], Generic[T]

DataLoader that supports saving and restoring the state of the dataset. When restoring, the dataloader and dataset must be instantiated with exactly the same parameters.

How this works (without worker processes)

  1. The state of the dataset is saved using megatron.energon.SavableDataset.save_state()

  2. (For compatibility) The state of the dataset is converted into inner arrays using megatron.energon.SavableDataset.merge_states().

  3. The state can be restored using megatron.energon.SavableDataset.restore_state() given the previously saved (and merged) state.

How this works (for worker processes)

  • The first issue is that worker processes use internal queues to pass loaded samples to the main process (also for collating). This means that the full state of the dataset is not directly accessible from the main process.

  • To solve this issue, the dataset regularly saves a checkpoint of its state to be able to resume from that state (and skip the samples that have already been yielded).

  • To have a consistent state, the sample index from the latest yielded samples is saved for all worker instances. Thus, the main process knows exactly which sample indexes should come next from which worker.

  • Internally, PyTorch iterates through the workers in order to retrieve the next worker’s samples. Unfortunately, the index of that next worker cannot be restored in PyTorch’s DataLoader, so the workers are shifted internally by that offset (see megatron.energon.WorkerConfig.worker_id_offset).

  1. The dataset is wrapped in a megatron.energon.SavableDatasetWrapper. This allows the main process to communicate with the worker and send commands to the workers and retrieve the results.

  2. The state of the dataset is saved using megatron.energon.SavableDatasetWrapper.get_checkpoint(). This gives the last checkpoint from the requested sample index and stores the offset (i.e. number of samples to skip) from that checkpoint.

  3. The state is merged using megatron.energon.SavableDatasetWrapper.merge_checkpoints(). This merges the states of all workers and returns a single state that can be used to restore the state of the dataset.

  4. The state can be restored using megatron.energon.SavableDatasetWrapper.restore_state() before a worker is started, such that all workers initially receive the same state array. Each worker first sets the worker index offset, then uses its own (shifted) index to get its required state from the merged state array.
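The bookkeeping behind steps 2 to 4 can be simulated in a few lines. The sketch below is a simplification under stated assumptions: each hypothetical worker yields an endless, deterministic stream, the "checkpoint" is always the start of that stream (so the skip offset equals the number of samples already consumed), and the next-worker index plays the role of worker_id_offset. worker_stream and consume are illustrative names only.

```python
from typing import Dict, Iterator, List, Optional, Tuple


def worker_stream(worker_id: int, skip: int) -> Iterator[str]:
    """Hypothetical worker: yields its own samples, starting after `skip` consumed ones."""
    index = skip
    while True:
        yield f"w{worker_id}-s{index}"
        index += 1


def consume(num_workers: int, steps: int, state: Optional[Dict] = None) -> Tuple[List[str], Dict]:
    """Round-robin over the workers like a DataLoader; return samples and a resumable state."""
    if state is None:
        state = {"consumed": [0] * num_workers, "offset": 0}
    consumed = list(state["consumed"])  # samples already yielded by each worker
    streams = [worker_stream(w, consumed[w]) for w in range(num_workers)]
    next_worker = state["offset"]  # plays the role of worker_id_offset
    out: List[str] = []
    for _ in range(steps):
        out.append(next(streams[next_worker]))
        consumed[next_worker] += 1
        next_worker = (next_worker + 1) % num_workers
    return out, {"consumed": consumed, "offset": next_worker}


full_run, _ = consume(3, 10)
first_part, saved = consume(3, 4)
second_part, _ = consume(3, 6, saved)
```

Resuming from the saved state continues exactly where the uninterrupted run would have been, which is the property the mechanism above aims for.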

__init__(dataset, *, checkpoint_every_sec=60, checkpoint_every_min_n_samples=None, n_checkpoints=None, gc_collect_every_n_steps=10, gc_freeze_at_start=True, prefetch_factor=2, cache_pool=None, watchdog_timeout_seconds=60, watchdog_initial_timeout_seconds=None, fail_on_timeout=False)[source]

Create the dataloader supporting saving and restoring the state.

Parameters:
  • dataset (SavableDataset[T]) – The dataset to load.

  • worker_config – The worker config to use

  • checkpoint_every_sec (float) – The time in seconds after which a checkpoint is saved. Restoring a checkpoint may take up to the same duration, whereas checkpointing more often adds overhead while reading data from the dataset, so choose this value accordingly. Only applies if using workers.

  • checkpoint_every_min_n_samples (int | None) – Overrides the minimum number of samples between checkpoints. Defaults to number of workers * 2. Only applies if using workers.

  • n_checkpoints (int | None) – The number of checkpoints to keep in memory. Only applies if using workers. If None, computes a suitable value.

  • gc_collect_every_n_steps (int) – The number of steps after which the garbage collector is called. Since we usually handle few but large tensors here, and Python’s garbage collector already tracks many objects just from imports, this can reduce the memory footprint considerably, and may even be necessary to avoid running out of memory.

  • gc_freeze_at_start (bool) – If true, the garbage collector is frozen at the start of the worker processes. This improves the garbage collection performance by a lot. In rare cases, this may cause issues and can be disabled. Keep enabled if you experience no issues.

  • cache_pool (CachePool | None) – If set, the cache pool to use for the dataset.

  • watchdog_timeout_seconds (float | None) – The watchdog timeout in seconds. If None, the watchdog is disabled.

  • watchdog_initial_timeout_seconds (float | None) – The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.

  • fail_on_timeout (bool) – If True, stops the whole process upon timeout, after printing a stack trace.

  • prefetch_factor (int)

can_restore_sample()[source]
Return type:

bool

cmd_queues: List[Queue]

The queues used to send commands to the workers

config()[source]

Get the configuration, which defines the dataset. Useful in conjunction with save_state and restore_state to match the configuration as well.

dataset: SavableDatasetWrapper[T] | SimpleSavableDatasetWrapper[T]

The wrapped dataset. For multiprocessing, this is a megatron.energon.SavableDatasetWrapper

id: int = 0

Class instance id

static next_id()[source]
Return type:

int

restore_sample(restore_key)[source]

Restores a sample from a key. This is useful for debugging the dataset.

Parameters:

restore_key (Tuple[str | int | tuple, ...])

Return type:

T

restore_state(state)[source]

Deprecated. Use restore_state_global (or restore_state_rank) instead.

Parameters:

state (Sequence[SavableDataLoaderState | None] | None)

Return type:

None

restore_state_global(state, *, src_rank=None)[source]

Restores the saved state from save_state_global (in torch distributed setup). The global state needs to be loaded on every rank that has a data loader instance.

Optionally, one can specify a src_rank and only provide the state once. In case of multiple data parallel groups, you must provide the state once in each data parallel group. In this case the src_rank is the rank within the data parallel group.

Parameters:
  • state (Sequence[SavableDataLoaderState | None] | None) – The state to restore, as saved by save_state_global.

  • src_rank (int | None) – The rank from which the state is broadcasted (within the data parallel group, if using DP groups).

Return type:

None

restore_state_rank(state)[source]

Restores the saved state for the current rank.

Parameters:

state (SavableDataLoaderState | None) – The state to restore, as saved by save_state_rank.

Return type:

None

result_queues: List[Queue]

The queues used to receive results from the workers

save_state(dst_rank)[source]

Deprecated. Use save_state_global (or save_state_rank) instead.

Parameters:

dst_rank (int)

Return type:

Sequence[SavableDataLoaderState | None] | None

save_state_global(global_dst_rank)[source]

Saves the state of the dataset globally, collecting the state from all ranks using torch distributed. The state can later be restored with restore_state_global, given the result of this method. Typical scenario: save the state to disk only on the dst_rank, while the other ranks do not save it. Later, restore it with restore_state_global, either loading it only on the dst_rank (and passing src_rank) or loading it on every rank.

Note: If you want to save/restore the state per rank separately, use save_state_rank and the corresponding restore_state_rank. Also, these do not rely on torch distributed.

Parameters:

global_dst_rank (int) – The state will be gathered to this rank. The rank refers to the global rank, not the rank within the data parallel group.

Returns:

The state of the dataset (or None, if not on dst_rank).

Return type:

Sequence[SavableDataLoaderState | None] | None

save_state_rank()[source]

Saves the state of the dataset for the current rank. Allows for restoring the state later using restore_state_rank, given the result of this method.

Returns:

The state of the dataset.

Return type:

SavableDataLoaderState | None

worker_config: WorkerConfig

The worker config

Parameters:
  • dataset (SavableDatasetWrapper[T] | SimpleSavableDatasetWrapper[T])

  • checkpoint_every_sec (float)

  • checkpoint_every_min_n_samples (int | None)

  • n_checkpoints (int | None)

  • gc_collect_every_n_steps (int)

  • gc_freeze_at_start (bool)

  • prefetch_factor (int | None)

  • cache_pool (CachePool | None)

  • watchdog_timeout_seconds (float | None)

  • watchdog_initial_timeout_seconds (float | None)

  • fail_on_timeout (bool)

class megatron.energon.SavableDataset(worker_config)[source]

Bases: IterableDataset[T_sample], Savable, Generic[T_sample], ABC

A dataset that can be saved and restored (i.e. the random state, internal buffers, etc.). I.e. it can be resumed from a checkpoint.

How dataset state saving works:

  1. The dataset state needs to be saved in all forked worker processes which contain a copy of the main dataset instance (see megatron.energon.SavableDataLoader). Each worker returns only its own state.

  2. The main process merges the states on the main dataset instance via megatron.energon.SavableDataset.merge_states() (the main instance does not hold the worker states, as the workers were forked).

  3. The main process saves the merged state to the checkpoint.
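The per-worker save, merge and restore cycle can be sketched with a toy class that saves the fields listed in _savable_fields, mirroring the save_state() and restore_state() descriptions. ToySavableDataset is hypothetical, and "merging" is modeled simply as a list indexed by worker id, which is an assumption about the merged layout.

```python
import random
from typing import Any, Dict, List, Optional


class ToySavableDataset:
    """Hypothetical dataset whose resumable state consists of the fields in _savable_fields."""

    _savable_fields = ("position", "rng_state")

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self.reset_state_own()

    def reset_state_own(self) -> None:
        # Initial state: nothing consumed yet, RNG freshly seeded.
        self.position = 0
        self.rng_state = random.Random(self.seed).getstate()

    def next_value(self) -> float:
        rng = random.Random()
        rng.setstate(self.rng_state)
        value = rng.random()
        self.rng_state = rng.getstate()
        self.position += 1
        return value

    def save_state(self) -> Dict[str, Any]:
        # Called inside each worker: returns only this worker's state.
        return {name: getattr(self, name) for name in self._savable_fields}

    def restore_state(self, state: Optional[Dict[str, Any]]) -> None:
        # None means "restore the initial state".
        if state is None:
            self.reset_state_own()
            return
        for name in self._savable_fields:
            setattr(self, name, state[name])


# Step 1: every (forked) worker saves its own state.
workers = [ToySavableDataset(seed=w) for w in range(2)]
workers[0].next_value()
workers[0].next_value()
workers[1].next_value()
# Step 2: the main process merges them (here: a list indexed by worker id).
merged: List[Dict[str, Any]] = [w.save_state() for w in workers]
# Step 3 would write `merged` to the checkpoint; later, fresh workers restore their slot.
restored = [ToySavableDataset(seed=w) for w in range(2)]
for worker_id, ds in enumerate(restored):
    ds.restore_state(merged[worker_id])
```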

Parameters:

worker_config (WorkerConfig)

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

abstractmethod config()[source]

Return a config dict that can be used to check if datasets have the same settings. Keys in the config dicts that start with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_deep()[source]

Resets the state of the dataset to the initial state. Can only be called in a worker process.

Return type:

None

abstractmethod reset_state_own()[source]

Resets the state of the dataset to the initial state. Can only be called in a worker process.

Return type:

None

restore_sample(restore_key)[source]

Restores a sample from a restore key. The key type is generic, because it might be either an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default, this raises an exception (if not implemented, the dataset is assumed to be non-deterministic; determinism is not guaranteed).

Parameters:

restore_key (Tuple[str | int | tuple, ...])

Return type:

T_sample

restore_state(state)[source]

Restores the state of the dataset. This will restore the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Parameters:

state (FlexState) – The state of the dataset as savable object. If None, restore initial state.

Return type:

None

save_state()[source]

Saves the state of the dataset. This will save and return the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Return type:

FlexState

worker_config: WorkerConfig
abstractmethod worker_has_samples()[source]

Returns True if the worker’s split has samples. This is used to determine if this dataset yields anything.

Return type:

bool

class megatron.energon.ShuffleBufferDataset(dataset, size, *, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Shuffle buffer for the dataset.
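A shuffle buffer is a standard technique for approximately shuffling a stream that does not fit in memory. The sketch below shows the classic algorithm; it is not claimed to match ShuffleBufferDataset exactly (for example, it does not save or restore its buffer), and the name shuffle_buffer is hypothetical.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shuffle_buffer(samples: Iterable[T], size: int, rng: random.Random) -> Iterator[T]:
    """Yield samples in approximately random order using a buffer of `size` elements."""
    buffer: List[T] = []
    for sample in samples:
        buffer.append(sample)
        if len(buffer) >= size:
            # Swap a random element to the end and pop it (O(1) removal).
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer
```

A larger size gives a more thorough shuffle at the cost of memory; with size=1 the input order is kept.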

__init__(dataset, size, *, worker_config)[source]

Create a shuffle buffer for the dataset.

Parameters:
  • dataset (SavableDataset[T_sample])

  • size (int)

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Keys in the config dicts that start with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(restore_key)[source]

Restores a sample from a restore key. The key type is generic, because it might be either an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default, this raises an exception (if not implemented, the dataset is assumed to be non-deterministic; determinism is not guaranteed).

Parameters:

restore_key (Tuple[str | int | tuple, ...])

Return type:

T_sample

size: int
class megatron.energon.SimilarityInterleavedSample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, images, texts, audio=None, video=None, similarity_matrix=None, matched_text_indices=None)[source]

Bases: Sample

Sample type for interleaved media such as text with images, but without image-text alignment. That alignment has to be assigned from the similarity matrix.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • images (List[Tensor])

  • texts (List[str])

  • audio (List[Tensor] | None)

  • video (List[Tensor] | None)

  • similarity_matrix (Tensor | None)

  • matched_text_indices (List[int] | None)

audio: List[Tensor] | None

The optional audio samples of the sequence

images: List[Tensor]

The images of the sequence

matched_text_indices: List[int] | None

The index within texts representing the sentence that this image is matched to

similarity_matrix: Tensor | None

Similarity matrix between image and text entries in the sequence

texts: List[str]

The texts of the sequence

video: List[Tensor] | None

The optional video frames of the sequence
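One simple way to assign the alignment from the similarity matrix is to match each image to its most similar text (argmax per row). This is an assumption for illustration: it presumes that rows are images and columns are texts, and the helper name is hypothetical. With a torch tensor, similarity_matrix.argmax(dim=1).tolist() would compute the same thing.

```python
from typing import List, Sequence


def match_images_to_texts(similarity: Sequence[Sequence[float]]) -> List[int]:
    """For each image (row), return the index of the most similar text (column)."""
    # max() returns the first index among ties.
    return [max(range(len(row)), key=row.__getitem__) for row in similarity]


matched_text_indices = match_images_to_texts([[0.1, 0.9, 0.3], [0.7, 0.2, 0.1]])
```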

class megatron.energon.SimilarityInterleavedWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[SimilarityInterleavedSample]

Parameters:

path (EPath)

exception megatron.energon.SkipSample[source]

Bases: Exception

Exception to raise in the map_fn to skip a sample.
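The skip-on-exception pattern can be sketched with a local stand-in exception, so the example runs without energon installed. SkipSample here is a hypothetical local class, and map_skipping and encode are illustrative names.

```python
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class SkipSample(Exception):
    """Local stand-in for megatron.energon.SkipSample."""


def map_skipping(fn: Callable[[T], U], samples: Iterable[T]) -> Iterator[U]:
    """Apply fn to each sample, dropping samples for which fn raises SkipSample."""
    for sample in samples:
        try:
            result = fn(sample)
        except SkipSample:
            continue  # skip this sample and keep iterating
        yield result


def encode(text: str) -> str:
    # Hypothetical map_fn that rejects blank texts.
    if not text.strip():
        raise SkipSample()
    return text.upper()
```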

class megatron.energon.SourceInfo(*, dataset_path, index, shard_name, file_names)[source]

Bases: object

Information about the source of a sample, i.e. where the data was loaded from.

Parameters:
  • dataset_path (EPath)

  • index (str | int)

  • shard_name (str)

  • file_names (tuple[str, ...])

dataset_path: EPath

The path to the dataset

file_names: tuple[str, ...]

The names of the files in the shard used to create the sample

index: str | int

The index of the sample in the dataset

shard_name: str

The name of the shard tar file

class megatron.energon.StandardWebdatasetFactory(path, *, sample_type, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[T_sample], Generic[T_sample]

This dataset sample loader factory uses the sample type (e.g. as given in a dataset.yaml) and applies the default loading logic, which includes decoding images, videos and containers.

Parameters:
  • path (EPath)

  • sample_type (Type[T_sample])

__init__(path, *, sample_type, **kwargs)[source]

Factory for the standard webdataset sample loader.

Parameters:
  • path (EPath) – Path to the dataset (passed to parent)

  • sample_type (Type[T_sample]) – Type of the sample to be loaded

  • auto_decode – If true, use the default webdataset sample decoder.

  • image_decode – Defines the output format of decoded images.

  • ignore_decoder_errors – If true, ignore errors when decoding.

  • subflavors – Subflavors dictionary to set for all loaded samples.

  • field_map – Mapping from the webdataset fields to the sample fields.

  • sample_loader – Function to load the sample from the webdataset fields. May be a string in order to load a function from a module, or a callable directly.

  • part_filter – Filter for the parts to load. May be a string in order to load a function from a module, or a callable directly.

  • split_part – Which part to load (e.g. ‘train’, ‘val’, ‘test’).

  • training – If true, apply shuffling and loop the dataset.

  • worker_config – Configuration for the workers.

  • shuffle_over_epochs – Only effective if training=True. How many epochs to shuffle over. If 1, every sample is seen exactly once per epoch. If > 1, samples (or rather shard slices) are shuffled within this number of epochs (i.e. randomly selected without replacement). If -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices are drawn with replacement).

  • parallel_shard_iters – Number of shards opened in parallel per worker, shuffling between them.

  • max_samples_per_sequence – Maximum number of samples per sequence (i.e. how many samples are iterated sequentially).

  • split_config – Config file to use for shard split definitions.

  • handler – Exception handler. Args: (exception, key).
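The three shuffle_over_epochs regimes can be illustrated with a generator over shard slice names. This is a conceptual sketch under assumptions (slices are plain strings, a single RNG, no per-worker splitting), and draw_shard_slices is a hypothetical name.

```python
import random
from typing import Iterator, List


def draw_shard_slices(slices: List[str], shuffle_over_epochs: int, rng: random.Random) -> Iterator[str]:
    """Yield shard slices forever according to a shuffle_over_epochs value."""
    if shuffle_over_epochs == -1:
        # Effectively infinite epochs: draw slices with replacement.
        while True:
            yield rng.choice(slices)
    if shuffle_over_epochs < 1:
        raise ValueError("shuffle_over_epochs must be -1 or >= 1")
    while True:
        # Shuffle a pool holding each slice shuffle_over_epochs times,
        # then draw without replacement until the pool is exhausted.
        pool = slices * shuffle_over_epochs
        rng.shuffle(pool)
        yield from pool
```

With a value of 1, every window of len(slices) draws is a permutation; with 2, every slice appears exactly twice per window of 2 * len(slices) draws.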

class megatron.energon.SystemFileStore(base_dir=None)[source]

Bases: FileStore[bytes]

A FileStore that reads files directly from the file system.

Parameters:

base_dir (EPath | str | None)

__init__(base_dir=None)[source]
Parameters:

base_dir (EPath | str | None) – The base directory to use for relative paths. If None, you should only pass absolute paths to __getitem__.

get_path()[source]

Returns the path to the dataset.

Return type:

str

class megatron.energon.TaskEncoder[source]

Bases: ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]

Base class for task encoders.

Task encoding follows these steps:
  1. Data comes from the dataset

  2. megatron.energon.TaskEncoder.encode_sample() / megatron.energon.TaskEncoder.preencode_sample() is called on each sample

  3. megatron.energon.TaskEncoder.select_samples_to_pack() is called on the buffer of samples

  4. megatron.energon.TaskEncoder.postencode_sample() is called on each sample of the current pack

  5. megatron.energon.TaskEncoder.pack_selected_samples() is called on the selected sample pack

  6. megatron.energon.TaskEncoder.batch() is called on the list of encoded samples

  7. megatron.energon.TaskEncoder.encode_batch() is called on the batch

  8. The batch is yielded to the main process

  9. megatron.energon.Batch.to_device() is called on the encoded batch

  10. The resulting encoded batch is passed to the network
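For the non-packing path (steps 1, 2, 6, 7 and 8), the flow can be sketched with plain functions. The token encoding, padding scheme and run_pipeline driver are hypothetical stand-ins for a real task encoder; only the ordering of the stages follows the list above.

```python
from typing import Dict, Iterable, Iterator, List, Tuple

Tokens = List[int]


def encode_sample(sample: str) -> Tokens:
    # Step 2: encode one sample (here: characters to code points).
    return [ord(ch) for ch in sample]


def batch(samples: List[Tokens]) -> Tuple[List[Tokens], List[int]]:
    # Step 6: collate the encoded samples by right-padding them with 0.
    lengths = [len(s) for s in samples]
    width = max(lengths)
    return [s + [0] * (width - len(s)) for s in samples], lengths


def encode_batch(raw_batch: Tuple[List[Tokens], List[int]]) -> Dict[str, List[Tokens]]:
    # Step 7: derive an attention mask from the true lengths.
    tokens, lengths = raw_batch
    mask = [[1] * n + [0] * (len(row) - n) for row, n in zip(tokens, lengths)]
    return {"tokens": tokens, "mask": mask}


def run_pipeline(dataset: Iterable[str], batch_size: int) -> Iterator[Dict[str, List[Tokens]]]:
    pending: List[Tokens] = []
    for sample in dataset:  # step 1: data comes from the dataset
        pending.append(encode_sample(sample))
        if len(pending) == batch_size:
            yield encode_batch(batch(pending))  # steps 6 to 8
            pending = []
    # A trailing incomplete batch is dropped here, similar to batch_drop_last=True.
```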

batch(samples)[source]

Collate a list of encoded samples into a batch. May raise megatron.energon.SkipSample to skip a batch.

Parameters:

samples (List[T_encoded_sample])

Return type:

T_raw_batch

batch_group_criterion(sample)[source]

Return a group criterion for the sample. Default implementation does not group (effectively, it returns a single value (None, None), thus only one group is used). Returns the key of the bucket to put this sample into, and the size of the bucket (=batch size). The bucket size must always be the same for the same bucket key.

May raise megatron.energon.SkipSample to skip a batch.

Parameters:

sample (T_encoded_sample)

Return type:

Tuple[Hashable, int | None]
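Bucketing by a group criterion can be sketched as one buffer per key that emits a batch when it reaches that key's size. group_batches and by_length are hypothetical, falling back to a default batch size when the criterion returns None is an assumption, and incomplete buckets are simply left unflushed in this sketch.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable, Iterable, Iterator, List, Optional, Tuple, TypeVar

T = TypeVar("T")


def group_batches(
    samples: Iterable[T],
    criterion: Callable[[T], Tuple[Hashable, Optional[int]]],
    default_batch_size: int,
) -> Iterator[List[T]]:
    """Fill one bucket per key and emit a batch whenever a bucket reaches its size."""
    buckets: Dict[Hashable, List[T]] = defaultdict(list)
    for sample in samples:
        key, size = criterion(sample)
        if size is None:
            size = default_batch_size  # assumption: fall back to the global batch size
        bucket = buckets[key]
        bucket.append(sample)
        if len(bucket) == size:
            yield list(bucket)
            bucket.clear()


def by_length(text: str) -> Tuple[Hashable, Optional[int]]:
    # Hypothetical criterion: short texts in batches of 2, long ones in batches of 3.
    return ("short", 2) if len(text) <= 3 else ("long", 3)
```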

build_batch(dataset, *, batch_size, batch_drop_last=False, packing_buffer_size=None, worker_config)[source]

Applies the batcher to the dataset.

Parameters:
  • dataset (SavableDataset[T_encoded_sample])

  • batch_size (int | None)

  • batch_drop_last (bool)

  • packing_buffer_size (int | None)

  • worker_config (WorkerConfig)

Return type:

SavableDataset[T_raw_batch]

build_cook_crude_sample(dataset, *, worker_config, subflavors, get_primary_aux, aux=None)[source]

Applies the sample cooker to the dataset if we have cookers registered.

Return type:

SavableDataset[T_sample]

build_encode_sample(dataset, *, worker_config)[source]

Applies the sample encoder to the dataset.

Return type:

SavableDataset[T_encoded_sample]

build_train_datasets(*, datasets, worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, virtual_epoch_length=0, shuffle_buffer_size=None, blend_mode=DatasetBlendMode.NONE, repeat=True)[source]

Combines train datasets into a single dataset.

Parameters:
  • datasets (List[LoadedDataset])

  • worker_config (WorkerConfig)

  • batch_size (int | None)

  • batch_drop_last (bool)

  • packing_buffer_size (int | None)

  • virtual_epoch_length (int)

  • shuffle_buffer_size (int | None)

  • blend_mode (DatasetBlendMode)

  • repeat (bool)

Return type:

SavableDataset[T_batch]

build_val_datasets(*, datasets, worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, limit=None)[source]

Combines val datasets into a single dataset.

Parameters:
  • datasets (List[LoadedDataset])

  • worker_config (WorkerConfig)

  • batch_size (int)

  • batch_drop_last (bool)

  • packing_buffer_size (int | None)

  • limit (int | None)

Return type:

SavableDataset[T_batch]

property cache: CachePool

Returns the cache pool to use for caching out sample data to disk (for use with cookers / aux file stores). This is set and configured externally by the loader.

cook_crude_sample(sample, get_primary_aux, **aux)[source]

Cooks a crude sample.

Parameters:
  • sample (T_sample | CrudeSample) – The sample to cook.

  • get_primary_aux (Callable[[], FileStore]) – A function that returns the (cached) primary auxiliary dataset.

  • **aux (FileStore) – The auxiliary side dishes to use for cooking.

Return type:

T_sample

Returns: The cooked sample.

cookers: Sequence[Cooker[T_sample]] = ()
property current_batch_index: int

Returns the current index for the next batch yielded from the current worker. Each batch on the current rank will get a strictly increasing unique number. Counting happens on each rank separately (i.e. each rank will get the same numbers for same batch index).

property current_sample_index: int

Returns the current index for the next sample yielded from the current routine (e.g. for encode_sample, batch, or encode_batch). Each routine gets a number counting the calls to that function. The number is unique across workers, but it is not synchronized between them, so it may increase at different rates (e.g. if batching does not behave the same for all batches). When a sample is restored, this number is also restored, so it can be relied on to reproduce the randomness of a sample deterministically.
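A sketch of the deterministic-randomness idea: seed a private RNG from the sample index so that a restored sample draws the same random values. In a task encoder, the index would presumably come from self.current_sample_index; here it is a plain argument, and jitter_for_sample and base_seed are hypothetical.

```python
import random


def jitter_for_sample(sample_index: int, base_seed: int = 1234) -> float:
    """Return a reproducible random offset in [-1, 1] for one sample."""
    # A private RNG seeded from the index: the same index always gives the same value,
    # independent of the global random state.
    rng = random.Random(base_seed * 1_000_003 + sample_index)
    return rng.uniform(-1.0, 1.0)
```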

decoder: SampleDecoder | None = <megatron.energon.flavors.webdataset.sample_decoder.SampleDecoder object>

The decoder to use for decoding samples. Set manually as needed to override options.

Parameters:

sample (dict)

Return type:

dict

encode_batch(batch)[source]

Encode a batch of samples. May raise megatron.energon.SkipSample to skip a batch. Alternatively, this can be a generator that yields (or ignores) new batches.

Parameters:

batch (T_raw_batch)

Return type:

T_batch | Generator[T_batch, None, None]

encode_sample(sample)[source]

Encode a single sample. May raise megatron.energon.SkipSample to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples. If this is defined, preencode_sample() and postencode_sample() must not be defined.

Parameters:

sample (T_sample)

Return type:

T_encoded_sample | Generator[T_encoded_sample, None, None]

pack_selected_samples(samples)[source]

Given one set of samples to pack, returns the final packed sample. Packing is only active when packing_buffer_size is set. Internally this stage is called “final_packing”.

Parameters:

samples (List[T_encoded_sample]) – The samples to pack into a single sample

Return type:

T_encoded_sample

Returns: The final packed sample.

postencode_sample(sample)[source]

Post-encode a single sample. May raise megatron.energon.SkipSample to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples. Use in conjunction with packing and caching. If this is defined, encode_sample() must not be defined.

Parameters:

sample (T_sample)

Return type:

T_encoded_sample | Generator[T_encoded_sample, None, None]

preencode_sample(sample)[source]

Pre-encode a single sample. May raise megatron.energon.SkipSample to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples. Use in conjunction with packing and caching. If this is defined, encode_sample() must not be defined.

Parameters:

sample (T_sample)

Return type:

T_sample | Generator[T_sample, None, None]

select_samples_to_pack(samples)[source]

For packing, selects the samples to be packed together. Packing is only active when packing_buffer_size is set. Internally this stage is called “pre_packing”.

Parameters:

samples (List[T_encoded_sample]) – The samples to pre-pack. A full buffer will be passed into the function.

Return type:

List[List[T_encoded_sample]]

Returns: The pre-packed samples as a list of lists of samples.
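The two packing stages can be sketched as a pre-packing step that groups a full buffer of samples, followed by a final packing step that merges each group. This sketch assumes encoded samples are plain token-id lists and uses a greedy first-fit-decreasing policy with a hypothetical max_len; the selection policy is whatever your own select_samples_to_pack() override implements, so this is only one possible choice.

```python
from typing import List

Sample = List[int]  # assumption: an encoded sample is a list of token ids


def select_samples_to_pack(samples: List[Sample], max_len: int = 8) -> List[List[Sample]]:
    """Greedy first-fit decreasing: place each sample in the first group with room."""
    groups: List[List[Sample]] = []
    lengths: List[int] = []
    # Sorting longest first tends to leave fewer gaps.
    for sample in sorted(samples, key=len, reverse=True):
        for i, used in enumerate(lengths):
            if used + len(sample) <= max_len:
                groups[i].append(sample)
                lengths[i] += len(sample)
                break
        else:
            # No existing group has room: open a new one.
            groups.append([sample])
            lengths.append(len(sample))
    return groups


def pack_selected_samples(samples: List[Sample]) -> Sample:
    """Final packing: concatenate one selected group into a single sample."""
    packed: Sample = []
    for sample in samples:
        packed.extend(sample)
    return packed
```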

class megatron.energon.TextSample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, text)[source]

Bases: Sample

Sample type for simple text.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • text (str)

text: str

The text of the sample

class megatron.energon.TextWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[TextSample]

Parameters:

path (EPath)

class megatron.energon.VQAOCRWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[VQAOCRSample]

Parameters:

path (EPath)

class megatron.energon.VQASample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, image, context, answers=None, answer_weights=None)[source]

Bases: Sample

Sample type for visual question answering.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • image (Tensor)

  • context (str)

  • answers (List[str] | None)

  • answer_weights (Tensor | None)

answer_weights: Tensor | None

The weights of the possible answers. Optionally available.

answers: List[str] | None

The possible answers. Not set for testing.

context: str

The context/question for the image

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.VQAWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[VQASample]

Parameters:

path (EPath)

class megatron.energon.VidQASample(*, __key__, __restore_key__, __subflavors__=None, __sources__=None, video, context, answers=None, answer_weights=None)[source]

Bases: Sample

Sample type for video question answering.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavors__ (Dict[str, Any] | None)

  • __sources__ (tuple[SourceInfo, ...] | None)

  • video (AVDecoder)

  • context (str)

  • answers (List[str] | None)

  • answer_weights (Tensor | None)

answer_weights: Tensor | None

The weights of the possible answers. Optionally available.

answers: List[str] | None

The possible answers. Not set for testing.

context: str

The context/question for the video.

video: AVDecoder

The video data, containing the video frames and audio.

class megatron.energon.VidQAWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[VidQASample]

Parameters:

path (EPath)

class megatron.energon.WorkerConfig(*, rank, world_size, num_workers, data_parallel_group=None, seed_offset=0, worker_debug_path=None, worker_log_level=0, _worker_debug_file=None, _worker_debug_file_worker_id=None)[source]

Bases: object

Provides information about the current worker and the global configuration, giving each data parallel rank its proper config. Every rank from 0 to world_size-1 must be used. If the ranks are set incorrectly, the datasets may yield duplicate data or data may be missing, because the data is split over the data parallel ranks using this config. You may deliberately set the same rank on multiple processes if they need to retrieve the same data.

Parameters:
  • rank (int)

  • world_size (int)

  • num_workers (int)

  • data_parallel_group (ProcessGroup | None)

  • seed_offset (int)

  • worker_debug_path (str | None)

  • worker_log_level (int)

  • _worker_debug_file (TextIO | None)

  • _worker_debug_file_worker_id (int | None)

property active_worker_batch_index: int

Returns the current batch index for the actively iterating worker.

active_worker_config: ClassVar[WorkerConfig | None] = None

The currently active worker config within the iterating worker.

property active_worker_sample_index: int

Returns the current sample index for the actively iterating worker.

assert_worker()[source]

Checks that the current process is a worker (if workers are configured) and that the workers are properly configured.

config()[source]
Return type:

Dict[str, Any]

data_parallel_group: ProcessGroup | None

If not using all ranks for data parallel, set this to the corresponding group.

static default_worker_config(num_workers=4, data_parallel_group=None)[source]

Returns the default worker config using torch distributed if available. If torch distributed is not available, a single local rank is assumed.

Parameters:
  • num_workers (int)

  • data_parallel_group (ProcessGroup | None)

Return type:

WorkerConfig

global_rank()[source]

Returns the global rank of this worker config, rather than its rank within the data parallel group.

Return type:

int

global_worker_id(override_local_worker_id=None)[source]

Returns the global worker index, computed by multiplying the rank by the number of workers per rank and adding the local worker id. Alternatively, you can override the local worker id.

Parameters:

override_local_worker_id (int, optional) – The local worker id to override. None means the current worker, which is the default.

Return type:

int
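To see why every rank must be used, and why reusing a rank duplicates data, the following sketch numbers workers globally as rank * num_workers + local_worker_id and shards sample indices round-robin over those global ids. The round-robin sharding and both helper names are assumptions for illustration; the actual split in energon may differ.

```python
from typing import List


def global_worker_id(rank: int, num_workers: int, local_worker_id: int) -> int:
    # Workers of rank r occupy ids r*num_workers .. r*num_workers + num_workers - 1.
    return rank * num_workers + local_worker_id


def samples_per_process(
    num_samples: int, configured_ranks: List[int], world_size: int, num_workers: int
) -> List[List[int]]:
    """Sample indices each process would read (hypothetical round-robin split).

    `configured_ranks` holds the rank each process was (possibly wrongly) given.
    """
    total_workers = world_size * num_workers
    per_process: List[List[int]] = []
    for rank in configured_ranks:
        indices: List[int] = []
        for local in range(num_workers):
            gid = global_worker_id(rank, num_workers, local)
            indices.extend(range(gid, num_samples, total_workers))
        per_process.append(sorted(indices))
    return per_process
```

With ranks 0 and 1 the eight samples are covered exactly once; configuring both processes as rank 0 makes them read the same samples while samples 2, 3, 6 and 7 are never read.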

num_workers: int

The number of workers per rank. May be 0 to disable worker processes.

rank: int

The data parallel rank/id of the current process.

rank_worker_id()[source]

Returns the id of the current worker within the current rank.

Return type:

int

seed_offset: int

The offset used when computing the worker seeds (see worker_seed()).

should_log(level)[source]
Parameters:

level (int)

Return type:

bool

worker_activate(sample_index, override_global_rank=None, cache_pool=None)[source]

Activates the worker config for the current worker and marks it as actively iterating. Must be called before calling next() on the datasets.

Parameters:
  • sample_index (int)

  • override_global_rank (int | None)

  • cache_pool (CachePool | None)

worker_deactivate()[source]

Deactivates the worker config for the current worker, marking it as no longer actively iterating. Must be called after calling next() on the datasets.

worker_debug_path: str | None
worker_id_offset: ClassVar[int] = 0
worker_log(data)[source]

Logs the given data to the worker debug file.

Parameters:

data (dict)

Return type:

None

worker_log_level: int

Log level for worker logging.

worker_pop_sample_index()[source]

Pops the most recently pushed sample index from the sample index stack. Should be called by wrapping datasets after calling their inner datasets.

worker_push_sample_index(sample_index)[source]

Pushes a new sample index onto the sample index stack. Should be called by wrapping datasets before calling their inner datasets.

Parameters:

sample_index (int)
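The push/pop pair forms a stack: a wrapping dataset pushes its own sample index before calling into its inner dataset and pops it afterwards, so the index visible at any point belongs to the innermost active routine. A minimal sketch with a hypothetical SampleIndexStack class (not energon's implementation), using try/finally so the pop also happens when the inner call raises:

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


class SampleIndexStack:
    """Hypothetical stand-in for the WorkerConfig sample index stack."""

    def __init__(self) -> None:
        self._stack: List[int] = []

    def push(self, sample_index: int) -> None:
        self._stack.append(sample_index)

    def pop(self) -> int:
        return self._stack.pop()

    @property
    def current(self) -> int:
        # The innermost (most recently pushed) index.
        return self._stack[-1]


def call_inner(stack: SampleIndexStack, sample_index: int, inner: Callable[[], T]) -> T:
    # Wrapping dataset: push before calling the inner routine, pop afterwards.
    stack.push(sample_index)
    try:
        return inner()
    finally:
        stack.pop()
```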

worker_seed(override_local_worker_id=None)[source]

Returns the seed for the current worker (or a specified worker). The seed is computed from the current worker id and the seed offset. Alternatively, you can override the local worker id with a fixed one to pregenerate seeds for multiple workers.

Parameters:

override_local_worker_id (int, optional) – The local worker id to override. None means the current worker, which is the default.

Return type:

int
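One way to derive a reproducible per-worker seed from the seed offset and the global worker id is to hash the pair. The SHA-256 scheme below is an assumption for illustration, and energon's actual formula may differ. Passing an explicit local worker id mirrors override_local_worker_id and lets you pregenerate the seeds of all workers of a rank.

```python
import hashlib
from typing import List


def worker_seed(seed_offset: int, rank: int, num_workers: int, local_worker_id: int) -> int:
    # Hypothetical scheme: hash (seed_offset, global worker id) into 32 bits.
    global_id = rank * num_workers + local_worker_id
    digest = hashlib.sha256(f"{seed_offset}:{global_id}".encode()).digest()
    return int.from_bytes(digest[:4], "little")


def pregenerate_seeds(seed_offset: int, rank: int, num_workers: int) -> List[int]:
    # Like calling worker_seed(override_local_worker_id=w) for every local worker w.
    return [worker_seed(seed_offset, rank, num_workers, w) for w in range(num_workers)]
```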

world_size: int

The total number of data parallel processes.

megatron.energon.basic_sample_keys(crude_sample, additional_source_info=())[source]

A convenience helper to extract the basic keys from a crude sample, which you will always need to forward to the cooked sample.

Parameters:
  • crude_sample (dict)

  • additional_source_info (tuple[SourceInfo, ...])

Return type:

dict
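A cooker typically spreads the result of basic_sample_keys() into the sample constructor and then adds its own fields. The sketch below mimics that pattern without importing energon: the local basic_sample_keys and TextSample are simplified stand-ins, the "txt" field name is hypothetical, and appending additional_source_info to __sources__ is an assumed behavior.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


def basic_sample_keys(crude_sample: dict, additional_source_info: tuple = ()) -> Dict[str, Any]:
    # Simplified stand-in: forward the bookkeeping keys every sample needs.
    sources = tuple(crude_sample.get("__sources__") or ()) + tuple(additional_source_info)
    return {
        "__key__": crude_sample["__key__"],
        "__restore_key__": crude_sample["__restore_key__"],
        "__subflavors__": crude_sample.get("__subflavors__"),
        "__sources__": sources or None,
    }


@dataclass
class TextSample:
    # Simplified stand-in for megatron.energon.TextSample.
    __key__: str
    __restore_key__: Tuple[Any, ...]
    __subflavors__: Optional[Dict[str, Any]]
    __sources__: Optional[tuple]
    text: str


def cook_text(crude_sample: dict) -> TextSample:
    # Hypothetical cooker: forward the basic keys, then add the payload.
    return TextSample(**basic_sample_keys(crude_sample), text=crude_sample["txt"])
```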

megatron.energon.batch_list(batch)[source]

Collate a batch into a list, without stacking or padding.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.batch_pad_stack(batch)[source]

Stack a batch of arbitrary-sized tensors padded with 0s.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.batch_stack(batch)[source]

Stack a batch of tensors.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.concat_pad(batch)[source]

Concat a batch of arbitrary-sized tensors padded with 0s.

Parameters:

batch (List[Any])

Return type:

Any
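The difference between batch_pad_stack() and concat_pad() can be illustrated with nested Python lists standing in for tensors. Stacking adds a new leading batch dimension and pads every item to the longest one; concatenation joins along the existing first dimension and pads only the remaining dimension. These are simplified 1-D/2-D sketches, not energon's tensor implementations.

```python
from typing import List


def pad_stack(batch: List[List[int]]) -> List[List[int]]:
    """Stack 1-D 'tensors' into a 2-D batch, right-padding with 0s."""
    max_len = max(len(x) for x in batch)
    return [x + [0] * (max_len - len(x)) for x in batch]


def concat_pad(batch: List[List[List[int]]]) -> List[List[int]]:
    """Concatenate 2-D 'tensors' along dim 0, right-padding dim 1 with 0s."""
    max_width = max(len(row) for item in batch for row in item)
    return [row + [0] * (max_width - len(row)) for item in batch for row in item]
```

For two items of shapes (1, 2) and (2, 3), stacking would need a (2, 2, 3) result, whereas concatenation yields (3, 3).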

megatron.energon.cooker(fn=None, *, need_cache=False, need_primary=False)[source]

Decorator to mark a function as a cooker, optionally enabling cache and primary dataset arguments.

Parameters:
  • fn (F | None)

  • need_cache (bool)

  • need_primary (bool)

Return type:

F | Callable[[F], F]

megatron.energon.generic_batch(batch)[source]

Based on the types/shapes of the batch, either pads and stacks the values or returns them as a list. Recurses into structures (dict, dataclass, namedtuple) and applies the same logic to each field.

Parameters:

batch (List[Any])

Return type:

Any
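The recursive part of generic_batch() can be sketched in plain Python. Here, lists of numbers stand in for tensors (padded and stacked), any other leaf such as a string is returned as a list, and dicts, dataclasses and namedtuples are recursed field by field. The leaf rules are simplified assumptions, not energon's exact logic.

```python
import dataclasses
from typing import Any, List


def generic_batch(batch: List[Any]) -> Any:
    first = batch[0]
    if isinstance(first, dict):
        return {key: generic_batch([item[key] for item in batch]) for key in first}
    if dataclasses.is_dataclass(first) and not isinstance(first, type):
        return type(first)(**{
            f.name: generic_batch([getattr(item, f.name) for item in batch])
            for f in dataclasses.fields(first)
        })
    if isinstance(first, tuple) and hasattr(first, "_fields"):  # namedtuple
        return type(first)(*(generic_batch(list(values)) for values in zip(*batch)))
    if isinstance(first, list) and all(isinstance(v, (int, float)) for x in batch for v in x):
        # "Tensor" leaf: pad to the longest item and stack.
        max_len = max(len(x) for x in batch)
        return [x + [0] * (max_len - len(x)) for x in batch]
    # Any other leaf (e.g. strings): keep as a plain list.
    return list(batch)


@dataclasses.dataclass
class Pair:
    # Hypothetical sample type used in the usage example.
    tokens: List[int]
    caption: str
```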

megatron.energon.generic_concat(batch)[source]

Based on the types/shapes of the batch, either pads and concatenates the values or returns them as a list. Recurses into structures (dict, dataclass, namedtuple) and applies the same logic to each field.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.get_loader(dataset, *, worker_config=None, prefetch_factor=2, cache_pool=None, watchdog_timeout_seconds=60, watchdog_initial_timeout_seconds=None, fail_on_timeout=False)[source]

Get a dataloader for the given dataset.

Parameters:
  • dataset (SavableDataset[T]) – The dataset to create a loader for.

  • worker_config (WorkerConfig | None) – Deprecated. Please pass this to the dataset instead.

  • cache_pool (CachePool | None) – If set, the cache pool to use for the dataset.

  • watchdog_timeout_seconds (float | None) – The timeout in seconds. If None, the watchdog is disabled.

  • watchdog_initial_timeout_seconds (float | None) – The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.

  • fail_on_timeout (bool) – If True, stops the whole process upon timeout, after printing a stack trace.

  • prefetch_factor (int)

Returns:

The instantiated torch.utils.data.DataLoader, yielding batches from the dataset.

Return type:

BasicDataLoader[T]

megatron.energon.get_savable_loader(dataset, *, worker_config=None, checkpoint_every_sec=60, checkpoint_every_min_n_samples=None, n_checkpoints=None, gc_collect_every_n_steps=10, prefetch_factor=2, cache_pool=None, watchdog_timeout_seconds=60, watchdog_initial_timeout_seconds=None, fail_on_timeout=False)[source]

Get a savable dataloader for the given dataset, whose state can be saved and restored.

Parameters:
  • dataset (SavableDataset[T]) – The dataset to create a loader for.

  • worker_config (WorkerConfig | None) – Deprecated. Please pass this to the dataset instead.

  • checkpoint_every_sec (float) – The interval in seconds after which an internal checkpoint is saved. Restoring a checkpoint may take up to the same duration, while checkpointing more often adds overhead when reading data from the dataset, so choose this value accordingly. Only applies if using workers.

  • checkpoint_every_min_n_samples (int | None) – Overrides the minimum number of samples between checkpoints. Defaults to number of workers * 2. Only applies if using workers.

  • n_checkpoints (int | None) – The number of internal checkpoints to keep. Only applies if using workers. If None, computes a suitable value.

  • cache_pool (CachePool | None) – If set, the cache pool to use for the dataset.

  • watchdog_timeout_seconds (float | None) – The timeout in seconds. If None, the watchdog is disabled.

  • watchdog_initial_timeout_seconds (float | None) – The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.

  • fail_on_timeout (bool) – If True, stops the whole process upon timeout, after printing a stack trace.

  • gc_collect_every_n_steps (int)

  • prefetch_factor (int)

Returns:

The instantiated megatron.energon.SavableDataLoader, yielding batches from the dataset and allowing the state of the dataset to be saved.

Return type:

SavableDataLoader[T]

megatron.energon.get_train_dataset(path, *, split_part='train', worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, shuffle_buffer_size, max_samples_per_sequence, virtual_epoch_length=0, shuffle_over_epochs_multiplier=1, task_encoder=<megatron.energon.task_encoder.base.DefaultTaskEncoder object>, repeat=True, **kwargs)[source]

Get the training dataset with sensible defaults. See get_dataset for more details.

The following recipe will be used:
  • megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset)

  • task_encoder.encode_sample

  • (megatron.energon.MixDataset if mixing)

  • megatron.energon.BatchDataset with task_encoder.batch for collation

  • task_encoder.encode_batch

  • megatron.energon.EpochizeDataset (if virtual_epoch_length is set)
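The recipe above is a chain of lazy stages. The following sketch wires up equivalent stages with plain generators to show the order of operations (shuffle buffer, mixing and epochizing omitted). The function names are hypothetical; this is not energon's implementation.

```python
from typing import Any, Callable, Iterable, Iterator, List


def map_stage(fn: Callable[[Any], Any], items: Iterable[Any]) -> Iterator[Any]:
    for item in items:
        yield fn(item)


def batch_stage(items: Iterable[Any], batch_size: int, drop_last: bool = False) -> Iterator[List[Any]]:
    batch: List[Any] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch  # smaller final batch, kept unless drop_last is set


def train_pipeline(
    raw: Iterable[Any],
    encode_sample: Callable[[Any], Any],
    collate: Callable[[List[Any]], Any],
    encode_batch: Callable[[Any], Any],
    batch_size: int,
    drop_last: bool = False,
) -> Iterator[Any]:
    encoded = map_stage(encode_sample, raw)  # like task_encoder.encode_sample
    batches = batch_stage(encoded, batch_size, drop_last)  # like BatchDataset
    collated = map_stage(collate, batches)  # like task_encoder.batch
    return map_stage(encode_batch, collated)  # like task_encoder.encode_batch
```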

Parameters:
  • path (str | EPath | Path) – Path to the dataset.

  • split_part (Literal['train'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • batch_size (int | None) – Size of a batch. If None, do not batch.

  • batch_drop_last (bool) – If True, drop the last batch if it is smaller than batch_size.

  • shuffle_buffer_size (int | None) – Size of the sample shuffle buffer (before task encoding).

  • max_samples_per_sequence (int | None) – If set, limit the number of samples per sample-sequence to this.

  • virtual_epoch_length (int) – If set (non-zero), the dataset is epochized to this length: once this many items have been yielded, iteration is suspended and the for-loop returns, and the next for-loop continues where the previous one stopped. Otherwise, the dataset loops indefinitely.

  • shuffle_over_epochs_multiplier (int | None) – Shuffle the shards over this many epochs.

  • task_encoder (TaskEncoder[Any, Any, Any, T]) – Task encoder to use.

  • repeat (bool) – By default, the inner datasets will loop. If set to False, stop iteration after one epoch. Must only be set to False in conjunction with blend_epochized in the metadataset if one is used.

  • cache_pool – If set, the cache pool to use for the dataset.

  • **kwargs – Additional arguments to the dataset constructor.

  • packing_buffer_size (int | None) – If set, enables sample packing using a packing buffer of this size.

Returns:

The dataset.

Return type:

SavableDataset[T]
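virtual_epoch_length makes an endless stream behave like a sequence of finite epochs: each for-loop stops after that many items, and the next for-loop resumes where the previous one stopped. A minimal sketch of that behavior with a hypothetical Epochize wrapper (not energon's EpochizeDataset):

```python
import itertools
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class Epochize(Generic[T]):
    def __init__(self, source: Iterable[T], length: int) -> None:
        self._it = iter(source)  # shared iterator: state survives between for-loops
        self._length = length

    def __iter__(self) -> Iterator[T]:
        # Yield at most `length` items, then stop; the next loop continues.
        return itertools.islice(self._it, self._length)

    def __len__(self) -> int:
        return self._length
```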

megatron.energon.get_val_dataset(path, *, split_part='val', worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, limit=None, task_encoder=<megatron.energon.task_encoder.base.DefaultTaskEncoder object>, **kwargs)[source]

Get the validation/test dataset with sensible defaults. See get_dataset for more details.

The following recipe will be used:
  • megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset)

  • task_encoder.encode_sample

  • (megatron.energon.MixDataset if mixing)

  • megatron.energon.BatchDataset with task_encoder.batch for collation

  • megatron.energon.LimitDataset (if limit is set)

  • task_encoder.encode_batch

Parameters:
  • path (str | EPath | Path) – Path to the dataset.

  • split_part (Literal['val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • batch_size (int) – Size of a batch.

  • batch_drop_last (bool) – If True, drop the last batch if it is smaller than batch_size.

  • limit (int | None) – If set, limit the number of batches loaded from the dataset to this.

  • task_encoder (TaskEncoder[Any, Any, Any, T]) – Task encoder to use.

  • **kwargs – Additional arguments to the dataset constructor.

  • packing_buffer_size (int | None) – If set, enables sample packing using a packing buffer of this size.

Returns:

The loaded dataset.

Return type:

SavableDataset[T]

megatron.energon.get_val_datasets(path, *, split_part='val', worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, limit=None, task_encoder=<megatron.energon.task_encoder.base.DefaultTaskEncoder object>, **kwargs)[source]

Get the validation/test datasets with sensible defaults, each together with its source dataset. See get_dataset for more details.

The following recipe will be used:
  • megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset)

  • task_encoder.encode_sample

  • (megatron.energon.MixDataset if mixing)

  • megatron.energon.BatchDataset with task_encoder.batch for collation

  • megatron.energon.LimitDataset (if limit is set)

  • task_encoder.encode_batch

Parameters:
  • path (str | EPath | Path) – Path to the dataset.

  • split_part (Literal['val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • batch_size (int) – Size of a batch.

  • batch_drop_last (bool) – If True, drop the last batch if it is smaller than batch_size.

  • limit (int | None) – If set, limit the number of batches loaded from the dataset to this.

  • task_encoder (TaskEncoder[Any, Any, Any, T]) – Task encoder to use.

  • **kwargs – Additional arguments to the dataset constructor.

  • packing_buffer_size (int | None) – If set, enables sample packing using a packing buffer of this size.

Returns:

The loaded validation datasets, each paired with its source dataset factory.

Return type:

List[Tuple[SavableDataset[T], BaseCoreDatasetFactory]]

megatron.energon.homogeneous_concat_mix(samples)[source]

Mixes a list of batches into a single batch. The default implementation is to concat the batches if they are all of the same type, otherwise return a list of batches.

Parameters:

samples (List[T_batch_in]) – The samples to mix.

Returns:

The mixed batch.

Return type:

T_batch
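The default mixing rule can be sketched with lists standing in for batches: if all batches share one type they are concatenated, otherwise the batches are returned side by side as a list. In this simplified illustration, concatenation is only implemented for list batches; the real function operates on energon batch types.

```python
from typing import Any, List


def homogeneous_concat_mix(batches: List[Any]) -> Any:
    first_type = type(batches[0])
    if isinstance(batches[0], list) and all(type(b) is first_type for b in batches):
        # Same type everywhere: concatenate into one batch.
        mixed: List[Any] = []
        for b in batches:
            mixed.extend(b)
        return mixed
    # Mixed types: keep the batches separate.
    return list(batches)
```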

megatron.energon.load_dataset(path, **kwargs)[source]

Loads a (meta)dataset.

Parameters:

path (str | EPath | Path)

Return type:

DatasetLoaderInterface

megatron.energon.prepare_metadataset(path)[source]
Parameters:

path (EPath)

megatron.energon.stateless(fn=None, *, restore_seeds=False, failure_tolerance=None)[source]

Decorator to mark a function of the task encoder as restorable.

Parameters:
  • fn (Callable[[...], T] | None) – The function to decorate.

  • restore_seeds (bool) – Whether to restore the seeds for the function. If True, the seeds are set from the sample index and the worker seed, so that they can be restored when a sample is restored from that function.

  • failure_tolerance (int | None) – The number of consecutive exceptions that are handled, after which a FatalSampleError is raised for this function.

Return type:

Callable[[Callable[[…], T]], Callable[[…], T]] | Callable[[…], T]

Usage:

from megatron.energon import DefaultTaskEncoder, stateless

class MyTaskEncoder(DefaultTaskEncoder):
    @stateless
    def encode_sample(self, sample):
        ...

# Or, if randomness is used (e.g. for augmentations):
class MyAugmentingTaskEncoder(DefaultTaskEncoder):
    @stateless(restore_seeds=True)
    def encode_sample(self, sample):
        ...

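To show why restore_seeds=True makes augmentations reproducible, here is a simplified stand-in decorator. It reseeds Python's random module from a worker seed and the sample index before each call, and once failure_tolerance consecutive exceptions have occurred it raises a fatal error. The decorator name, the local FatalSampleError class, the seed formula, the exact threshold, and passing sample_index explicitly are all assumptions for illustration; energon takes the index from the worker config instead.

```python
import functools
import random
from typing import Any, Callable, Optional


class FatalSampleError(Exception):
    """Local stand-in for energon's FatalSampleError."""


def stateless_sketch(worker_seed: int, failure_tolerance: Optional[int] = None):
    def decorator(fn: Callable[[Any], Any]) -> Callable[..., Any]:
        consecutive_failures = 0

        @functools.wraps(fn)
        def wrapper(sample: Any, sample_index: int) -> Any:
            nonlocal consecutive_failures
            # The seed depends only on (worker_seed, sample_index), so re-running
            # the same sample index reproduces the same random choices.
            random.seed(worker_seed * 1_000_003 + sample_index)
            try:
                result = fn(sample)
            except Exception as exc:
                consecutive_failures += 1
                if failure_tolerance is not None and consecutive_failures >= failure_tolerance:
                    raise FatalSampleError(f"{consecutive_failures} consecutive failures") from exc
                return None  # handled: the sample is skipped
            consecutive_failures = 0
            return result

        return wrapper

    return decorator


@stateless_sketch(worker_seed=42, failure_tolerance=2)
def augment(sample: Optional[float]) -> float:
    if sample is None:
        raise ValueError("corrupt sample")
    # Random augmentation: add jitter drawn from the seeded RNG.
    return sample + random.random()
```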
class megatron.energon.task_encoder.cooking.Cooker(cook, has_subflavors=None)[source]

Bases: Generic[T_sample]

A cooker transforms a crude sample (a simple dict) into a specific sample type inheriting from Sample. The cook method performs the transformation; the other fields select the samples that this cooker can transform. If no filters are provided, the cooker transforms any CrudeSample.

Parameters:
  • cook (Callable[[...], T_sample])

  • has_subflavors (dict | None)

cook: Callable[[...], T_sample]
has_subflavors: dict | None = None
is_match(crude_sample)[source]
Parameters:

crude_sample (CrudeSample)

Return type:

bool

property need_cache: bool
property need_primary: bool
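A cooker pairs a cook function with optional filters. The sketch below models the has_subflavors filter as "every given key/value pair must appear in the crude sample's __subflavors__" and dispatches to the first matching cooker. That matching rule, the first-match dispatch, and the CookerSketch and cook_any names are assumptions for illustration, not energon's code.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional


@dataclass
class CookerSketch:
    cook: Callable[[dict], Any]
    has_subflavors: Optional[Dict[str, Any]] = None

    def is_match(self, crude_sample: dict) -> bool:
        if self.has_subflavors is None:
            return True  # no filter: matches any crude sample
        subflavors = crude_sample.get("__subflavors__") or {}
        return all(k in subflavors and subflavors[k] == v for k, v in self.has_subflavors.items())


def cook_any(cookers: Iterable[CookerSketch], crude_sample: dict) -> Any:
    # Hypothetical dispatcher: use the first cooker whose filters match.
    for cooker in cookers:
        if cooker.is_match(crude_sample):
            return cooker.cook(crude_sample)
    raise ValueError("no cooker matches this crude sample")
```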
megatron.energon.task_encoder.cooking.basic_sample_keys(crude_sample, additional_source_info=())[source]

A convenience helper to extract the basic keys from a crude sample, which you will always need to forward to the cooked sample.

Parameters:
  • crude_sample (dict)

  • additional_source_info (tuple[SourceInfo, ...])

Return type:

dict

megatron.energon.task_encoder.cooking.cooker(fn: None = None) Callable[[F], F][source]
megatron.energon.task_encoder.cooking.cooker(*, need_cache: bool = False, need_primary: bool = False) Callable[[F], F]

Decorator to mark a function as a cooker, optionally enabling cache and primary dataset arguments.

megatron.energon.task_encoder.cooking.get_cooker_need_cache(fn)[source]

Get whether a cooker function needs the cache.

Parameters:

fn (Callable[[...], T_sample])

Return type:

bool

megatron.energon.task_encoder.cooking.get_cooker_need_primary(fn)[source]

Get whether a cooker function needs the primary dataset.

Parameters:

fn (Callable[[...], T_sample])

Return type:

bool

megatron.energon.av

class megatron.energon.av.AVData(*, video_clips: list[torch.Tensor] | None, video_timestamps: list[tuple[float, float]] | None, audio_clips: list[torch.Tensor] | None, audio_timestamps: list[tuple[float, float]] | None)[source]

Bases: object

Parameters:
  • video_clips (list[Tensor] | None)

  • video_timestamps (list[tuple[float, float]] | None)

  • audio_clips (list[Tensor] | None)

  • audio_timestamps (list[tuple[float, float]] | None)

audio_clips: list[Tensor] | None

A list of audio clips in the shape (channels, samples)

audio_timestamps: list[tuple[float, float]] | None

The timestamps for the audio clips. List of tuples (start, end) in seconds

video_clips: list[Tensor] | None

A list of video clips in the shape (frames, channels, h, w)

video_timestamps: list[tuple[float, float]] | None

The timestamps for the video clips. List of tuples (start, end) in seconds

class megatron.energon.av.AVDecoder(stream, suppress_warnings=False)[source]

Bases: object

A class that provides a flexible interface for decoding audio and video data.

This class allows users to control decoding parameters at runtime rather than having them fixed during initialization. It’s particularly useful for cases where different samples may need different decoding parameters.

Parameters:
  • stream (BinaryIO)

  • suppress_warnings (bool)

get_audio()[source]

Get the entire audio data from the stream.

Return type:

AVData

get_audio_clips(audio_clip_ranges, audio_unit='seconds')[source]

Get audio clips from the audio stream.

Parameters:
  • audio_clip_ranges (Sequence[tuple[float, float]]) – List of audio clip start and end positions in the given unit (see audio_unit)

  • audio_unit (Literal['samples', 'seconds']) – Unit of the audio clip positions (“samples” for sample number, “seconds” for timestamp)

Returns:

A tuple containing:
  • audio_clips: List of audio clips

  • audio_clips_timestamps: List of timestamps for each audio clip start and end in seconds

get_audio_duration()[source]

Get the duration of the audio stream.

Returns:

The duration of the audio stream in seconds

Return type:

float | None

get_audio_samples_per_second()[source]

Get the number of samples per second of the audio stream.

Return type:

int

get_clips(video_clip_ranges=None, audio_clip_ranges=None, video_unit='seconds', audio_unit='seconds', video_out_frame_size=None)[source]

Get clips from the video and/or audio streams. Given a list of (start, end) tuples, this method will decode the video and/or audio clips at the specified start and end times. The units of the start and end times are specified by the video_unit and audio_unit arguments.

Parameters:
  • video_clip_ranges (Sequence[tuple[float, float]] | None) – List of video clip start and end positions in the given unit (see video_unit)

  • audio_clip_ranges (Sequence[tuple[float, float]] | None) – List of audio clip start and end positions in the given unit (see audio_unit)

  • video_unit (Literal['frames', 'seconds']) – Unit of the video clip positions (“frames” for frame number, “seconds” for timestamp)

  • audio_unit (Literal['samples', 'seconds']) – Unit of the audio clip positions (“samples” for sample number, “seconds” for timestamp)

  • video_out_frame_size (tuple[int, int] | None) – Output size for video frames (width, height), or None to use the original frame size

Returns:

AVData containing the decoded video and audio clips

Return type:

AVData
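Clip ranges can be given in seconds or in frames/samples (video_unit="frames", audio_unit="samples"). Converting between the two uses the stream rate, as reported by get_video_fps() or get_audio_samples_per_second(). The helpers below are a small sketch assuming a constant frame rate; variable-frame-rate videos do not map seconds to frames this simply, and whether the end position is inclusive is not specified here.

```python
from typing import List, Sequence, Tuple


def seconds_to_units(ranges: Sequence[Tuple[float, float]], rate: float) -> List[Tuple[int, int]]:
    """Convert (start, end) ranges in seconds to frame or sample positions."""
    return [(round(start * rate), round(end * rate)) for start, end in ranges]


def units_to_seconds(ranges: Sequence[Tuple[int, int]], rate: float) -> List[Tuple[float, float]]:
    """Convert (start, end) frame or sample positions back to seconds."""
    return [(start / rate, end / rate) for start, end in ranges]
```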

get_frames(video_decode_audio=False)[source]

Decode the audio/video data with the specified parameters.

Parameters:

video_decode_audio (bool) – Whether to decode audio from video

Returns:

AVData containing the decoded frames and metadata, or None if decoding failed. The video tensor is in the shape (frames, channels, height, width). The audio tensor is in the shape (channels, samples).

Return type:

AVData | None

get_video()[source]

Get the entire video data from the stream (without audio).

Return type:

AVData

get_video_clips(video_clip_ranges, video_unit='seconds', video_out_frame_size=None)[source]

Get video clips from the video stream.

Parameters:
  • video_clip_ranges (Sequence[tuple[float, float]]) – List of video clip start and end positions in the given unit (see video_unit)

  • video_unit (Literal['frames', 'seconds']) – Unit of the video clip positions (“frames” for frame number, “seconds” for timestamp)

  • video_out_frame_size (tuple[int, int] | None) – Output size for video frames (width, height), or None to use the original frame size

Returns:

A tuple containing:
  • video_clips: List of video clips

  • video_clips_timestamps: List of timestamps for each video clip start and end in seconds

get_video_duration(get_frame_count=False)[source]

Get the duration of the video stream.

Parameters:

get_frame_count (bool) – Whether to return the number of frames in the video. This is a more costly operation.

Returns:

A tuple containing the duration in seconds, and the number of frames in the video

Return type:

tuple[float | None, int | None]

get_video_fps()[source]

Get the FPS of the video stream.

Return type:

float

get_video_with_audio()[source]

Get the entire video and audio data from the stream.

Return type:

AVData

has_audio_stream()[source]

Check if the stream has an audio stream.

Return type:

bool

has_video_stream()[source]

Check if the stream has a video stream.

Return type:

bool

seeker: Fastseek
stream: BinaryIO
suppress_warnings: bool
class megatron.energon.av.AVWebdatasetDecoder(video_decode_audio, av_decode='AVDecoder')[source]

Bases: object

A decoder class for audio and video data that provides a consistent interface for decoding media files.

This class encapsulates the decoding parameters and provides a callable interface that can be used with webdataset or other data loading pipelines. It supports both video and audio decoding with configurable parameters for frame extraction, resizing, and audio clip extraction.

Parameters:
  • video_decode_audio (bool) – Whether to decode audio from video files. If True, audio will be extracted alongside video frames.

  • av_decode (Literal['torch', 'AVDecoder', 'pyav']) – If “AVDecoder”, returns an AVDecoder instance for flexible decoding. If “torch”, returns decoded VideoData.

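Example

>>> from megatron.energon.av import AVWebdatasetDecoder
>>> decoder = AVWebdatasetDecoder(
...     video_decode_audio=True,
...     av_decode="AVDecoder"
... )
>>> with open("video.mp4", "rb") as f:
...     video_bytes = f.read()
>>> result = decoder("video.mp4", video_bytes)
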
read_av_data(key, data)[source]

Decoder function that returns an AVDecoder object for flexible decoding.

Parameters:
  • key (str) – The file extension or key

  • data (bytes) – The raw bytes of the media file

Returns:

An AVDecoder object that can be used to decode the media with custom parameters.

Return type:

AVDecoder

megatron.energon.av.get_clips_uniform(av_decoder, clip_duration_seconds, num_clips, request_video=False, request_audio=False, video_out_frame_size=None)[source]

Extracts a sequence of clips, such that each clip is of equal duration and the clips are equidistant from each other.

Parameters:
  • av_decoder (AVDecoder) – An AVDecoder instance.

  • clip_duration_seconds (float) – The duration of each clip in seconds.

  • num_clips (int) – The number of clips to extract.

  • request_video (bool) – Whether to request video clips.

  • request_audio (bool) – Whether to request audio clips.

  • video_out_frame_size (tuple[int, int] | None) – The size of the video frames to output, or None to use the original size.

Returns:

An AVData object containing the extracted video and audio clips.

Return type:

AVData
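One way to compute equal-length, equidistant clip ranges, as get_clips_uniform() describes, is to spread the clip start times evenly between 0 and duration - clip_duration. This sketch covers the spacing arithmetic only, under that assumption; centering a single clip is also an assumption, and energon's exact placement may differ.

```python
from typing import List, Tuple


def uniform_clip_ranges(duration: float, clip_duration: float, num_clips: int) -> List[Tuple[float, float]]:
    """Return num_clips (start, end) ranges in seconds, equally long and equally spaced."""
    if num_clips < 1:
        return []
    last_start = max(duration - clip_duration, 0.0)
    if num_clips == 1:
        starts = [last_start / 2]  # single clip: center it
    else:
        step = last_start / (num_clips - 1)
        starts = [i * step for i in range(num_clips)]
    return [(s, s + clip_duration) for s in starts]
```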

megatron.energon.av.get_single_frames_uniform(av_decoder, num_frames, *, video_out_frame_size=None, return_timestamps=False)[source]

Extracts a sequence of clips, such that each clip contains only a single frame and the frames are equidistant from each other.

Parameters:
  • av_decoder (AVDecoder) – An AVDecoder instance.

  • num_frames (int) – The number of frames to extract.

  • video_out_frame_size (tuple[int, int] | None) – The size of the video frames to output, or None to use the original size.

  • return_timestamps (bool) – Whether to also return the timestamps of the extracted frames.

Returns:

A tensor of shape (num_frames, channels, height, width) containing the extracted frames. If return_timestamps is True, a tuple of this tensor and the list of frame timestamps.

Return type:

Tensor | tuple[Tensor, list[float]]