megatron.energon

class megatron.energon.AugmentTaskEncoder(task_encoder)[source]

Bases: TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch], Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]

Augment a task encoder with additional functionality. By default, delegates everything to the original task encoder.

Parameters:

task_encoder (TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch])
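
For example, an augmenting wrapper can post-process every encoded sample while keeping the wrapped encoder's batching untouched. A minimal sketch (the wrapped encoder and the image field are assumptions for illustration):

    import torch

    from megatron.energon import AugmentTaskEncoder

    class NoiseAugmentTaskEncoder(AugmentTaskEncoder):
        """Adds a small amount of Gaussian noise to each encoded sample's image."""

        def encode_sample(self, sample):
            # Delegate the actual encoding to the wrapped task encoder ...
            encoded = super().encode_sample(sample)
            # ... then augment the result (assumes the encoded sample has an image tensor).
            encoded.image = encoded.image + 0.01 * torch.randn_like(encoded.image)
            return encoded

    # Usage: wrap an existing task encoder instance (MyTaskEncoder is hypothetical).
    # augmented = NoiseAugmentTaskEncoder(MyTaskEncoder())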

batch(samples)[source]

Batch a list of encoded samples into a raw batch. May raise megatron.energon.SkipSample to skip a batch.

Parameters:

samples (List[T_encoded_sample])

Return type:

T_raw_batch

encode_batch(batch_data)[source]

Encode a batch of samples. May raise megatron.energon.SkipSample to skip a batch. Alternatively, this can be a generator that yields (or ignores) new batches.

Parameters:

batch_data (T_raw_batch)

Return type:

T_batch

encode_sample(sample)[source]

Encode a single sample. May raise megatron.energon.SkipSample to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples.

Parameters:

sample (T_sample)

Return type:

T_encoded_sample

class megatron.energon.BaseCoreDatasetFactory[source]

Bases: Generic[T_sample], ABC

Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for joining in a joined dataset.

abstract build(worker_rotation_offset=0)[source]
Parameters:

worker_rotation_offset (int)

Return type:

SavableDataset[T_sample]

paths: List[EPath]
subflavor: str | None
subflavors: Dict[str, Any]
class megatron.energon.BaseWebdatasetFactory(path, *, split_part, training, worker_config, shuffle_over_epochs=1, parallel_shard_iters=None, max_samples_per_sequence=None, info_config='.info.yaml', split_config='split.yaml', part_filter=None, handler=<function reraise_exception>)[source]

Bases: BaseCoreDatasetFactory[T_sample], WebdatasetPreparator, Sharder, ErrorHandler, Generic[T_sample], ABC

Base class for all webdataset sample loader factories. Applies proper sharding across workers.

Parameters:
  • path (EPath)

  • split_part (str)

  • training (bool)

  • worker_config (WorkerConfig)

  • shuffle_over_epochs (int | None)

  • parallel_shard_iters (int | None)

  • max_samples_per_sequence (int | None)

  • info_config (str)

  • split_config (str)

  • part_filter (Callable[[str], bool] | None)

  • handler (Callable[[Exception, str | None], None])

build(worker_rotation_offset=0)[source]
Parameters:

worker_rotation_offset (int)

Return type:

SavableDataset[T_sample]

config()[source]
Return type:

Dict[str, Any]

abstract load_sample(raw_data)[source]

Loads the sample from the dataset.

Parameters:

raw_data (FilteredSample)

Return type:

T_sample

path: EPath
rank_shards: List[List[Sequence[ShardInfo]]]
sample_filter(key)[source]
Parameters:

key (str)

Return type:

bool

shards: List[ShardInfo]
training: bool
worker_config: WorkerConfig
class megatron.energon.Batch[source]

Bases: PinMemoryMixin

Base class for a batch dataclass. Provides a default implementation for pinning memory.
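
A batch type is typically defined as a dataclass deriving from Batch so that memory pinning works out of the box. A minimal sketch (the field layout is illustrative, not part of the library):

    from dataclasses import dataclass
    from typing import List

    from torch import Tensor

    from megatron.energon import Batch

    @dataclass
    class CaptioningRawBatch(Batch):
        # Stacked images of the whole batch, shape (B, C, H, W).
        image: Tensor
        # One caption string per sample; tokenization is left to the task encoder.
        caption: List[str]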

class megatron.energon.BatchDataset(dataset, batch_size, batcher, *, batcher_stateless=False, batcher_config=None, drop_last=False, error_handler=<function log_exception>, worker_config)[source]

Bases: BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]

This dataset wrapper transforms a dataset of samples into a dataset of batches.

Parameters:
  • dataset (SavableDataset[T_batch_sample])

  • batch_size (int)

  • batcher (Callable[[List[T_batch_sample]], T_batch])

  • batcher_stateless (bool)

  • batcher_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • drop_last (bool)

  • error_handler (Callable[[Exception, List[T_batch_sample]], None])

  • worker_config (WorkerConfig)

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

batch_size: int
batcher: Callable[[List[T_batch_sample]], T_batch]
can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

drop_last: bool
error_handler: Callable[[Exception, List[T_batch_sample]], None]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_batch

class megatron.energon.BlendDataset(*dataset_weights, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample]

This dataset wrapper blends multiple iterable datasets together according to the given weights. The datasets may be infinite. This dataset is always infinite.
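
Assuming each positional argument is a (dataset, weight) pair, as for MixBatchDataset below, a blend of two already-built savable datasets might look like this sketch (the weights are placeholders):

    from megatron.energon import BlendDataset, WorkerConfig

    def blend(ds_a, ds_b, worker_config: WorkerConfig) -> BlendDataset:
        # ds_a and ds_b are already-built SavableDataset instances.
        return BlendDataset(
            (ds_a, 0.7),  # roughly 70% of drawn samples come from ds_a
            (ds_b, 0.3),  # roughly 30% from ds_b
            worker_config=worker_config,
        )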

Parameters:
  • dataset_weights (Tuple[SavableDataset[T_sample], float])

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

exhausted: List[bool]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

weights: Tuple[float, ...]
class megatron.energon.CaptioningSample(__key__, __restore_key__, __subflavor__, __subflavors__, image, caption)[source]

Bases: Sample

Sample type for image captioning.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • image (Tensor)

  • caption (str)

caption: str

The caption string

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.CaptioningWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[CaptioningSample]

Parameters:

path (EPath)

class megatron.energon.ConcatDataset(*datasets, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset wrapper concatenates multiple iterable datasets together. The datasets must be finite, otherwise not all datasets can be sampled. This is only useful for validation / test datasets.

Parameters:
  • datasets (SavableDataset[T_sample])

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.Cooker(cook, is_subflavor=None, has_subflavors=None, condition=None)[source]

Bases: object

A cooker transforms a crude sample (simple dict) into a specific sample type inheriting from Sample.

The cook method performs the transformation; the other fields are used to select the samples that this cooker can transform. If no filters are provided, the cooker will transform any sample.
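
For instance, a cooker that turns crude captioning samples into CaptioningSample instances could look like the following sketch (the part names "png"/"txt", the dunder metadata keys and the subflavor filter are assumptions about how the crude data was written):

    import torch

    from megatron.energon import CaptioningSample, Cooker

    def cook_caption(raw: dict) -> CaptioningSample:
        # `raw` is the crude sample dict; forward its metadata and build the typed sample.
        # The part names and decoded value types depend on how the shards were written
        # and on the decoder settings.
        return CaptioningSample(
            __key__=raw["__key__"],
            __restore_key__=raw["__restore_key__"],
            __subflavor__=raw.get("__subflavor__"),
            __subflavors__=raw.get("__subflavors__"),
            image=torch.as_tensor(raw["png"]),
            caption=str(raw["txt"]),
        )

    # Only crude samples whose subflavors contain {"kind": "caption"} are handled here.
    caption_cooker = Cooker(cook=cook_caption, has_subflavors={"kind": "caption"})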

Parameters:
  • cook (Callable[[dict], Sample])

  • is_subflavor (str | None)

  • has_subflavors (dict | None)

  • condition (Callable[[dict], bool] | None)

condition: Callable[[dict], bool] | None = None
cook: Callable[[dict], Sample]
has_subflavors: dict | None = None
is_match(crude_sample)[source]
Parameters:

crude_sample (CrudeSample)

Return type:

bool

is_subflavor: str | None = None
class megatron.energon.CrudeSample[source]

Bases: dict

Generic sample type to be processed later.

class megatron.energon.CrudeWebdataset(path, *, subflavor=None, subflavors=None, part_filter=<function CrudeWebdataset.<lambda>>, auto_decode=True, image_decode='torchrgb', ignore_decoder_errors=False, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[CrudeSample]

The CrudeWebdataset is used to load crude / raw samples and decode them in the user code using so-called cookers.

See the documentation under “Crude Data” for more information.

Parameters:
  • path (EPath)

  • subflavor (str | None)

  • subflavors (Dict[str, Any] | None)

  • part_filter (str | List[str] | Callable[[str], bool])

  • auto_decode (bool)

  • image_decode (Literal['l8', 'rgb8', 'rgba8', 'l', 'rgb', 'rgba', 'torchl8', 'torchrgb8', 'torchrgba8', 'torchl', 'torchrgb', 'torch', 'torchrgba', 'pill', 'pil', 'pilrgb', 'pilrgba'])

  • ignore_decoder_errors (bool)

class megatron.energon.DatasetLoader(path, split_part=None, subflavor=None, subflavors=None, shuffle_over_epochs_multiplier=1, dataset_config='dataset.yaml', split_config='split.yaml')[source]

Bases: DatasetLoaderInterface

Loads a dataset from a path.

Parameters:
  • path (str | EPath)

  • split_part (str | None)

  • subflavor (str | None)

  • subflavors (Dict[str, Any] | None)

  • shuffle_over_epochs_multiplier (int | None)

  • dataset_config (str)

  • split_config (str)

dataset_config: str
get_dataset(*, training, split_part=None, worker_config, subflavor=None, subflavors=None, shuffle_over_epochs=1, split_config=None, dataset_config=None, **kwargs)[source]
Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (str | None) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration.

  • shuffle_buffer_size – Size of the sample shuffle buffer (before task encoding).

  • subflavor (str | None) – Subflavor to use, might be overridden by inner datasets.

  • subflavors (Dict[str, Any] | None) – Subflavors to use, might be overridden by inner datasets.

  • shuffle_over_epochs (int | None) – Shuffle the dataset over this many epochs.

  • **kwargs – Additional arguments to the dataset constructor.

  • split_config (str | None)

  • dataset_config (str | None)

Returns:

The loaded dataset

Return type:

BaseCoreDatasetFactory

get_datasets(*, training, split_part, worker_config, subflavor=None, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavor (str | None) – Set the default subflavor for all datasets.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner datasets shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, float | int | None]]]

path: str | EPath
post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

shuffle_over_epochs_multiplier: int | None
split_config: str
split_part: str | None
subflavor: str | None
subflavors: Dict[str, Any] | None
class megatron.energon.DatasetLoaderInterface[source]

Bases: ABC

General interface for a dataset loader.

abstract get_datasets(*, training, split_part, worker_config, subflavor=None, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavor (str | None) – Set the default subflavor for all datasets.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner datasets shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, float | int | None]]]

abstract post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

prepare(split_part=None)[source]

Prepares the loader by creating caches and other necessary structures on disk.

Parameters:

split_part (str | None) – Name of the split to load.

Returns:

List of paths to the cache paths. This is used for cleanup.

Return type:

Sequence[EPath]

class megatron.energon.DefaultDecoderWebdatasetFactory(path, *, auto_decode=True, image_decode='torchrgb', ignore_decoder_errors=False, audio_clip_duration=1, audio_num_clips=-1, video_decode='torch', video_decode_audio=False, video_num_frames=64, video_out_frame_size=(224, 224), **kwargs)[source]

Bases: DefaultGenericWebdatasetFactory[T_sample], Generic[T_sample]

Extends the default webdataset loading with decoding of contained files, such as images, videos or nested containers.

Parameters:
  • path (EPath)

  • auto_decode (bool)

  • image_decode (Literal['l8', 'rgb8', 'rgba8', 'l', 'rgb', 'rgba', 'torchl8', 'torchrgb8', 'torchrgba8', 'torchl', 'torchrgb', 'torch', 'torchrgba', 'pill', 'pil', 'pilrgb', 'pilrgba'])

  • ignore_decoder_errors (bool)

  • audio_clip_duration (int)

  • audio_num_clips (int)

  • video_decode (Literal['torch', 'AVData'])

  • video_decode_audio (bool)

  • video_num_frames (int)

  • video_out_frame_size (tuple)

audio_clip_duration: int

Duration of each audio clip in seconds.

audio_num_clips: int

Number of audio clips to extract (-1 for all).

config()[source]
Return type:

Dict[str, Any]

ignore_decoder_errors: bool

If true, ignore errors when decoding.

image_decode: Literal['l8', 'rgb8', 'rgba8', 'l', 'rgb', 'rgba', 'torchl8', 'torchrgb8', 'torchrgba8', 'torchl', 'torchrgb', 'torch', 'torchrgba', 'pill', 'pil', 'pilrgb', 'pilrgba']

Image decoding result type

load_sample(sample)[source]

Loads the sample from the dataset.

Parameters:

sample (FilteredSample)

Return type:

T_sample

video_decode: Literal['torch', 'AVData']

If “AVData”, returns an AVData instance for flexible decoding. If “torch”, returns decoded VideoData.

video_decode_audio: bool

Whether to decode audio from video files.

video_num_frames: int

Number of video frames to extract.

video_out_frame_size: tuple

Output size for video frames (width, height).

class megatron.energon.DefaultGenericWebdatasetFactory(path, *, subflavor=None, subflavors=None, field_map=None, sample_loader=None, part_filter=None, **kwargs)[source]

Bases: BaseWebdatasetFactory[T_sample], Generic[T_sample]

Default implementation of webdataset for generic samples and the generic config interface for use with dataset.yaml.

Parameters:
  • path (EPath)

  • subflavor (str | None)

  • subflavors (Dict[str, Any] | None)

  • field_map (Dict[str, str] | None)

  • sample_loader (str | Callable[[dict], dict] | None)

  • part_filter (str | List[str] | Callable[[str], bool] | None)

config()[source]
Return type:

Dict[str, Any]

load_sample(sample)[source]

Loads the sample from the dataset.

Parameters:

sample (FilteredSample)

Return type:

T_sample

class megatron.energon.DefaultTaskEncoder(*, encoded_sample_type=None, raw_batch_type=None, batch_type=None)[source]

Bases: TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch], ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]

The default task encoder supports automagically mapping to target types. You may override any methods to customize the behavior. By default, encode_sample is the identity function, batch calls _batch with the type of the first sample, and encode_batch is also the identity function. If you set any of encoded_sample_type, raw_batch_type or batch_type, the corresponding method returns that type, automatically mapping the fields (by name) to your new type.

Parameters:
  • encoded_sample_type (Type[T_encoded_sample] | None)

  • raw_batch_type (Type[T_raw_batch] | None)

  • batch_type (Type[T_batch] | None)
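
For example, if samples only need to be collated into a custom raw batch type, it is enough to pass raw_batch_type and let the automatic field mapping do the rest. A minimal sketch (the batch dataclass is illustrative):

    from dataclasses import dataclass
    from typing import List

    from torch import Tensor

    from megatron.energon import Batch, CaptioningSample, DefaultTaskEncoder

    @dataclass
    class CaptioningRawBatch(Batch):
        image: Tensor       # batched images, mapped by field name from CaptioningSample.image
        caption: List[str]  # one caption per sample, mapped from CaptioningSample.caption

    class CaptioningTaskEncoder(
        DefaultTaskEncoder[CaptioningSample, CaptioningSample, CaptioningRawBatch, CaptioningRawBatch]
    ):
        def __init__(self):
            # encode_sample/encode_batch keep their identity defaults;
            # batch() maps the batched fields into CaptioningRawBatch by name.
            super().__init__(raw_batch_type=CaptioningRawBatch)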

batch(samples)[source]

Batch a list of samples. The default implementation uses default batching to convert to _batch_type.

Parameters:

samples (List[T_encoded_sample])

Return type:

T_raw_batch

encode_batch(batch)[source]

Encode a batch of samples. The default implementation converts to the _encoded_batch_type.

Parameters:

batch (T_raw_batch)

Return type:

T_batch | Generator[T_batch, None, None]

encode_sample(sample)[source]

Encode a single sample. The default implementation converts to the _encoded_sample_type.

Parameters:

sample (T_sample)

Return type:

T_encoded_sample | Generator[T_encoded_sample, None, None]

class megatron.energon.EpochizeDataset(dataset, length, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Wraps the base dataset and creates epochs of the given length (number of samples). Keeps the underlying dataset iterator alive across epochs (i.e. if it is an infinite dataset, it will keep its state). Repeats the underlying dataset if the iterator is exhausted.

Parameters:
  • dataset (SavableDataset[T_sample])

  • length (int)

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

length: int
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.FilterDataset(dataset, *, filter_fn, filter_fn_config=None, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset wrapper applies a custom filter function to each sample and does not yield filtered samples.

Parameters:
  • dataset (SavableDataset[T_sample])

  • filter_fn (Callable[[T_sample], bool])

  • filter_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

filter_fn: Callable[[T_sample], bool]
filter_fn_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_sample

class megatron.energon.GcDataset(dataset, *, worker_config, every_n_iter=10, freeze=True)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Applies a garbage collection step regularly. This is needed because Python garbage collection does not work well with very large objects, such as tensors. The problematic case arises when a few hundred objects (some of them large tensors) are created and released every epoch, with many of them alive at the same time but released only later. Those objects may end up in gc generation 2, where they can linger until enough new objects have been created to trigger an automatic gen2 collection. To avoid this memory leak, gc.collect() is best called regularly. In addition, calling gc.freeze() before the loop removes the objects currently alive from garbage collection checks, making the gc faster.
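
The pattern this wrapper applies is plain Python garbage-collection control; it corresponds roughly to the following standalone sketch (not energon code):

    import gc

    def iterate_with_gc(inner_iter, every_n_iter=10, freeze=True):
        if freeze:
            # Exempt the long-lived objects that already exist from future GC checks,
            # which keeps the periodic collections below cheap.
            gc.freeze()
        for i, sample in enumerate(inner_iter):
            yield sample
            if (i + 1) % every_n_iter == 0:
                # Collect regularly so large short-lived objects (e.g. tensors)
                # do not pile up in generation 2 waiting for an automatic collection.
                gc.collect()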

Parameters:
  • dataset (SavableDataset[T_sample])

  • worker_config (WorkerConfig)

  • every_n_iter (int)

  • freeze (bool)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

every_n_iter: int
freeze: bool
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.GroupBatchDataset(dataset, fixed_batch_size, sample_group_key, batcher, *, batcher_stateless=False, batcher_config=None, drop_last=False, error_handler=<function log_exception>, worker_config)[source]

Bases: BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]

This dataset wrapper transforms a dataset of samples into a dataset of batches, grouped by some criterion. The reported length is not exact: since there is no fixed batch size, the number of batches cannot be predicted, so the inner dataset size is returned instead. An example use case: image-text samples that are grouped by image size into three size categories (e.g. 128x128, 256x256, 512x512) for efficient augmentation and batching.
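
For the image-size use case above, a sample_group_key callable might bucket samples by resolution and assign each bucket its own batch size. A sketch (the thresholds and batch sizes are illustrative, and an image field is assumed):

    from typing import Hashable, Optional, Tuple

    def group_by_image_size(sample) -> Tuple[Hashable, Optional[int]]:
        # Bucket by the longer image side; each bucket always uses the same batch size.
        longer_side = max(sample.image.shape[-2:])
        if longer_side <= 128:
            return ("small", 64)
        if longer_side <= 256:
            return ("medium", 32)
        return ("large", 16)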

Parameters:
  • dataset (SavableDataset[T_batch_sample])

  • fixed_batch_size (int | None)

  • sample_group_key (Callable[[T_batch_sample], Tuple[Hashable, int | None]])

  • batcher (Callable[[List[T_batch_sample]], T_batch])

  • batcher_stateless (bool)

  • batcher_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • drop_last (bool)

  • error_handler (Callable[[Exception, List[T_batch_sample]], None])

  • worker_config (WorkerConfig)

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

batcher: Callable[[List[T_batch_sample]], T_batch]
can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

drop_last: bool
error_handler: Callable[[Exception, List[T_batch_sample]], None]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_batch

restore_state(state)[source]

Restores the state of the dataset. This will restore the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Parameters:

state (FlexState) – The state of the dataset as savable object. If None, restore initial state.

Return type:

None

sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, int | None]]
save_state()[source]

Saves the state of the dataset. This will save and return the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Return type:

FlexState

class megatron.energon.ImageClassificationSample(__key__, __restore_key__, __subflavor__, __subflavors__, image, label=None, label_name=None)[source]

Bases: Sample

Sample type for classifying an image.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • image (Tensor)

  • label (int | None)

  • label_name (str | None)

image: Tensor

The input image tensor in the shape (C, H, W)

label: int | None

The class label of the image

label_name: str | None

The class label of the image

class megatron.energon.ImageClassificationWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[ImageClassificationSample]

Parameters:

path (EPath)

class megatron.energon.ImageSample(__key__, __restore_key__, __subflavor__, __subflavors__, image)[source]

Bases: Sample

Sample type for an image, e.g. for image reconstruction.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • image (Tensor)

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.ImageWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[ImageSample]

Parameters:

path (EPath)

class megatron.energon.InterleavedSample(__key__, __restore_key__, __subflavor__, __subflavors__, sequence)[source]

Bases: Sample

Sample type for interleaved media such as text with images.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • sequence (List[Tensor | str])

sequence: List[Tensor | str]

The interleaved media (either torch.tensor for an image, or str for text)

class megatron.energon.InterleavedWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[InterleavedSample]

Parameters:

path (EPath)

class megatron.energon.IterMapDataset(dataset, iter_map_fn, *, len_map_fn=<function IterMapDataset.<lambda>>, error_handler=<function log_exception>, stateless_iter_fn=False, iter_map_fn_config=None, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]

This dataset wrapper applies a custom function to transform the stream of samples and yields a new stream of samples. If used in a savable dataset context, it is critical that iter_map_fn is either stateless, or that its state is saved and restored externally.
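
An iter_map_fn receives the inner sample iterator and yields a new stream, so it can drop samples or emit several per input. A stateless sketch (the caption field is an assumption):

    from typing import Iterator

    def drop_short_captions(samples: Iterator) -> Iterator:
        # Yield only samples with a reasonably long caption; the function keeps no
        # state, so the wrapper can be created with stateless_iter_fn=True.
        for sample in samples:
            if len(sample.caption.split()) >= 3:
                yield sample

    # dataset = IterMapDataset(inner_dataset, drop_short_captions,
    #                          stateless_iter_fn=True, worker_config=worker_config)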

Parameters:
  • dataset (SavableDataset[T_sample])

  • iter_map_fn (Callable[[Iterator[T_sample]], Iterator[T_sample_out]])

  • len_map_fn (Callable[[int], int])

  • error_handler (Callable[[Exception, T_sample | None], None])

  • stateless_iter_fn (bool)

  • iter_map_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • worker_config (WorkerConfig)

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

error_handler: Callable[[Exception, T_sample | None], None]
iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]]
iter_map_fn_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
len_map_fn: Callable[[int], int]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_sample

stateless_iter_fn: bool
class megatron.energon.JoinedWebdatasetFactory(inner_datasets, *, training, worker_config, shuffle_over_epochs=1, parallel_shard_iters=None, max_samples_per_sequence=None, join_index, joiner, handler=<function reraise_exception>)[source]

Bases: BaseCoreDatasetFactory[T_sample], Sharder, ErrorHandler[T_sample], Generic[T_sample], ABC

Base class for all webdataset loaders. Applies proper sharding across workers. Can join multiple datasets.

Parameters:
  • inner_datasets (List[BaseWebdatasetFactory])

  • training (bool)

  • worker_config (WorkerConfig)

  • shuffle_over_epochs (int | None)

  • parallel_shard_iters (int | None)

  • max_samples_per_sequence (int | None)

  • join_index (EPath)

  • joiner (Type[T_sample] | Callable[[...], T_sample])

  • handler (Callable[[Exception, str | None], None])

build(worker_rotation_offset=0)[source]
Parameters:

worker_rotation_offset (int)

Return type:

SavableDataset[T_sample]

config()[source]
Return type:

Dict[str, Any]

inner_dataset_keys: List[str] | None
inner_datasets: List[BaseWebdatasetFactory]
join_index: EPath
load_sample(samples)[source]
Parameters:

samples (RawSampleData)

Return type:

T_sample

max_samples_per_sequence: int | None
parallel_shard_iters: int | None
part_datasets: SavableDataset[T_sample]
property paths: List[EPath]
shards: List[Sequence[ShardInfo]]
shuffle_over_epochs: int | None = 1
training: bool
worker_config: WorkerConfig
class megatron.energon.LimitDataset(dataset, length, *, reset_after_epoch=False, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Limits the length of the dataset.

Parameters:
  • dataset (SavableDataset[T_sample])

  • length (int)

  • reset_after_epoch (bool)

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

current_offset: int
length: int
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

worker_has_samples()[source]

Returns True if the worker’s split has samples. This is used to determine if this dataset yields anything.

Return type:

bool

class megatron.energon.LogSampleDataset(dataset, mode, worker_config, get_keys_fn=<function default_get_keys>)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset logs every yielded sample to the debug logs.

Parameters:
  • dataset (SavableDataset[T_sample])

  • mode (Literal['train', 'val'])

  • worker_config (WorkerConfig)

  • get_keys_fn (Callable[[T_sample], List[str] | None])

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

get_keys_fn: Callable[[T_sample], List[str] | None]
mode: Literal['train', 'val']
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.MapDataset(dataset, map_fn, *, error_handler=<function log_exception>, stateless_map_fn=False, map_fn_config=None, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]

This dataset wrapper applies a custom function to transform each sample.
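
A map_fn transforms one sample into another (or, as a generator, into several); for example, a simple per-sample normalization (assuming an image field holding uint8 pixels):

    def normalize_image(sample):
        # Scale pixel values into [0, 1]; the function keeps no state,
        # so the wrapper can be created with stateless_map_fn=True.
        sample.image = sample.image.float() / 255.0
        return sample

    # dataset = MapDataset(inner_dataset, normalize_image,
    #                      stateless_map_fn=True, worker_config=worker_config)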

Parameters:
  • dataset (SavableDataset[T_sample])

  • map_fn (Callable[[T_sample], T_sample_out | Generator[T_sample_out, None, None]])

  • error_handler (Callable[[Exception, T_sample], None])

  • stateless_map_fn (bool)

  • map_fn_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • worker_config (WorkerConfig)

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

error_handler: Callable[[Exception, T_sample], None]
map_fn: Callable[[T_sample], T_sample_out | Generator[T_sample_out, None, None]]
map_fn_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_sample_out

stateless_map_fn: bool
class megatron.energon.Metadataset(path, splits)[source]

Bases: DatasetLoaderInterface

Main entry for metadataset.

Parameters:
  • path (EPath | str)

  • splits (Dict[str, MetadatasetBlender])

get_datasets(*, training, split_part, worker_config, subflavor=None, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavor (str | None) – Set the default subflavor for all datasets.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner datasets shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, float | int | None]]]

post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

class megatron.energon.MetadatasetV2(path: megatron.energon.epathlib.epath.EPath, splits: Dict[str, megatron.energon.metadataset.metadataset_v2.MetadatasetBlend | megatron.energon.metadataset.metadataset_v2.MetadatasetBlendEpochized | megatron.energon.metadataset.metadataset_v2.MetadatasetJoin | megatron.energon.metadataset.metadataset_v2.DatasetReference])[source]

Bases: DatasetLoaderInterface

Parameters:
  • path (EPath)

  • splits (Dict[str, MetadatasetBlend | MetadatasetBlendEpochized | MetadatasetJoin | DatasetReference])

get_datasets(*, training, split_part, worker_config, subflavor=None, subflavors=None, shuffle_over_epochs_multiplier=1, **kwargs)[source]

Calls megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights.

Parameters:
  • training (bool) – If true, apply training randomization.

  • split_part (Literal['train', 'val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • subflavor (str | None) – Set the default subflavor for all datasets.

  • subflavors (Dict[str, Any] | None) – Set the default subflavors for all datasets.

  • shuffle_over_epochs_multiplier (int | None) – Multiply the inner datasets shuffle_over_epochs(_multiplier) by this factor. E.g. if the inner dataset has shuffle_over_epochs_multiplier=2 and this function has shuffle_over_epochs_multiplier=3, the inner dataset will be shuffled over 6 epochs. Shuffling over n epochs guarantees that each sample is seen exactly n times in n epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement).

  • **kwargs – Additional arguments to the dataset constructor.

Returns:

The dataset blending mode and the instantiated core datasets with their weights/repetitions.

Return type:

Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, float | int | None]]]

path: EPath
post_initialize(mds_path=None)[source]

Called to finally initialize the dataset.

Parameters:

mds_path (EPath | None)

prepare(split_part=None)[source]

Prepares the loader by creating caches and other necessary structures on disk.

Parameters:

split_part (str | None) – Name of the split to load.

Returns:

List of paths to the cache paths. This is used for cleanup.

Return type:

Sequence[EPath]

splits: Dict[str, MetadatasetBlend | MetadatasetBlendEpochized | MetadatasetJoin | DatasetReference]
class megatron.energon.MixBatchDataset(*dataset_weights, batch_size, batch_mix_fn=<function MixBatchDataset.<lambda>>, worker_config)[source]

Bases: BaseWrapperDataset[T_batch_in, T_batch], Generic[T_batch_in, T_batch]

This dataset wrapper blends multiple iterable datasets together according to the given weights. The datasets may be infinite. This dataset is always infinite. Effectively combines megatron.energon.BlendDataset and megatron.energon.BatchDataset.

Parameters:
  • dataset_weights (Tuple[SavableDataset[T_batch_in], float])

  • batch_size (int)

  • batch_mix_fn (Callable[[List[T_batch_in]], T_batch | Generator[T_batch, None, None]])

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.MultiChoiceVQASample(__key__, __restore_key__, __subflavor__, __subflavors__, image, context, choices=None, correct_choice_idx=0)[source]

Bases: Sample

Sample type for visual question answering.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • image (Tensor)

  • context (str)

  • choices (List[str] | None)

  • correct_choice_idx (int)

choices: List[str] | None

The candidate answers.

context: str

The context/question for the image

correct_choice_idx: int

The index of the correct answer.

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.MultiChoiceVQAWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[MultiChoiceVQASample]

Parameters:

path (EPath)

class megatron.energon.OCRSample(__key__, __restore_key__, __subflavor__, __subflavors__, image, text, block_boxes=None, block_classes=None, block_text=None, lines_boxes=None, lines_text=None, words_boxes=None, words_text=None, chars_boxes=None, chars_text=None)[source]

Bases: Sample

Sample type for optical character recognition.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • image (Tensor)

  • text (str)

  • block_boxes (Tensor | None)

  • block_classes (Tensor | List[str] | None)

  • block_text (List[str] | None)

  • lines_boxes (Tensor | None)

  • lines_text (List[str] | None)

  • words_boxes (Tensor | None)

  • words_text (List[str] | None)

  • chars_boxes (Tensor | None)

  • chars_text (List[str] | None)

block_boxes: Tensor | None

The bounding boxes of the blocks in the image float(N, 4|5<x, y, w, h>)

block_classes: Tensor | List[str] | None

The classes of the blocks in the image int(N, 1<block_class>)

block_text: List[str] | None

The text contained in each block (N,)

chars_boxes: Tensor | None

The bounding boxes of the chars in the image float(N, 4|5<x, y, w, h[, confidence]>)

chars_text: List[str] | None

The character contained in each char (N,)

image: Tensor

The input image tensor in the shape (C, H, W)

lines_boxes: Tensor | None

The bounding boxes of the lines in the image float(N, 4|5<x, y, w, h[, confidence]>)

lines_text: List[str] | None

The text contained in each line (N,)

text: str

The text contained in the image

words_boxes: Tensor | None

The bounding boxes of the words in the image float(N, 4|5<x, y, w, h[, confidence]>)

words_text: List[str] | None

The text contained in each word (N,)

class megatron.energon.OCRWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[OCRSample]

Parameters:

path (EPath)

class megatron.energon.PackingDataset(dataset, buffer_size, pre_packer, final_packer, *, final_packer_stateless=False, packer_config=None, error_handler=<function log_exception>, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_batch_sample], Generic[T_sample, T_batch_sample]

This dataset wrapper transforms samples of a dataset into chunks/packs of samples, which are then combined into a batch.

Parameters:
  • dataset (SavableDataset[T_sample])

  • buffer_size (int)

  • pre_packer (Callable[[List[T_sample]], List[List[T_sample]]])

  • final_packer (Callable[[List[T_sample]], T_batch_sample])

  • final_packer_stateless (bool)

  • packer_config (Dict[str, Any] | Callable[[], Dict[str, Any]] | None)

  • error_handler (Callable[[Exception, List[T_sample]], None])

  • worker_config (WorkerConfig)

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

buffer_size: int
can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

error_handler: Callable[[Exception, List[T_sample]], None]
final_packer: Callable[[List[T_sample]], T_batch_sample]
final_packer_stateless: bool
packer_config: Dict[str, Any] | Callable[[], Dict[str, Any]] | None
pre_packer: Callable[[List[T_sample]], List[List[T_sample]]]
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_sample

class megatron.energon.RepeatDataset(dataset, *, repeats=None, restart=True, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

This dataset repeats the inner dataset indefinitely or a specific number of repeats.

Parameters:
  • dataset (SavableDataset[T_sample])

  • repeats (int | float | None)

  • restart (bool)

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

repeats: int | float | None
reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

class megatron.energon.Sample(__key__, __restore_key__, __subflavor__, __subflavors__)[source]

Bases: ABC, PinMemoryMixin, ExtendableDataclassMixin

An abstract base class for one element of a batch. Each task should derive a specific subclass as a @dataclass, like megatron.energon.CaptioningBatchSample, and add the input and output fields as needed for training.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

classmethod derive_from(base_sample, **kwargs)[source]

Uses the base fields of Sample from base_sample (i.e. __key__, __restore_key__, __subflavor__, __subflavors__) and creates a new sample with the kwargs as fields. This is useful for creating new samples, while keeping the metadata of the base sample.

Parameters:
  • base_sample (Sample) – The base sample to copy the base fields / metadata from.

  • kwargs – The fields of the new sample.

Returns:

The new sample.

Return type:

T_sample
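
A typical use is inside encode_sample, returning a new sample type while carrying the metadata over. A sketch (EncodedCaptioningSample and the tokenizer are hypothetical):

    from dataclasses import dataclass

    from torch import Tensor

    from megatron.energon import CaptioningSample, Sample

    @dataclass
    class EncodedCaptioningSample(Sample):
        image: Tensor
        caption_tokens: Tensor

    def encode(sample: CaptioningSample, tokenizer) -> EncodedCaptioningSample:
        # __key__, __restore_key__, __subflavor__ and __subflavors__ are copied from `sample`.
        return EncodedCaptioningSample.derive_from(
            sample,
            image=sample.image,
            caption_tokens=tokenizer(sample.caption),
        )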

classmethod from_joined(*args, **kwargs)[source]

Creates a sample from joined samples. The samples are either passed as positional arguments or as keyword arguments. The first sample is the primary sample, which is used to initialize the key and subflavors.

In the default implementation, the joined samples' fields will be joined together, such that samples joined later update the fields last (i.e. take precedence), except for the key and subflavors. The restore key is later set externally.

Parameters:
  • args (Sample | None) – The samples to join (either this or kwargs is specified).

  • kwargs (Sample | None) – The samples to join (either this or args is specified). Not supported for the default implementation. Overwriting implementations may use this.

Returns:

The joined constructed sample.

Return type:

T_sample

class megatron.energon.SavableDataLoader(dataset, *, checkpoint_every_sec=60, checkpoint_every_min_n_samples=None, n_checkpoints=None, gc_collect_every_n_steps=10, gc_freeze_at_start=True, prefetch_factor=2)[source]

Bases: DataLoader[T], Generic[T]

DataLoader that supports saving and restoring the state of the dataset. When restoring, the dataloader and dataset must be instantiated with exactly the same parameters.
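
Typical per-rank save/restore usage looks roughly like the following sketch (how the loader itself is constructed is outside this excerpt; the file path is a placeholder and the state object is assumed to be picklable):

    import pickle

    def train_some_steps(loader, num_steps: int = 1000) -> None:
        it = iter(loader)
        for _ in range(num_steps):
            batch = next(it)  # ... run a training step on `batch` ...

    def save_loader_state(loader, path: str) -> None:
        # Save this rank's dataloader state, e.g. alongside the model checkpoint.
        with open(path, "wb") as f:
            pickle.dump(loader.save_state_rank(), f)

    def restore_loader_state(loader, path: str) -> None:
        # The loader must have been re-created with exactly the same parameters.
        with open(path, "rb") as f:
            loader.restore_state_rank(pickle.load(f))
        # Iterating now continues after the previously consumed samples.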

How this works (for no worker processes)

  1. The state of the dataset is saved using megatron.energon.SavableDataset.save_state()

  2. (for compatibility) The state of the dataset is converted to inner arrays using megatron.energon.SavableDataset.merge_states().

  3. The state can be restored using megatron.energon.SavableDataset.restore_state() given the previously saved (and merged) state.

How this works (for worker processes)

  • The first issue is that worker processes pass loaded samples to the main process through internal queues (which are also used for collating). This means that the whole state of the dataset is not directly accessible from the main process.

  • To solve this issue, the dataset regularly saves a checkpoint of its state to be able to resume from that state (and skip the samples that have already been yielded).

  • To have a consistent state, the sample index from the latest yielded samples is saved for all worker instances. Thus, the main process knows exactly which sample indexes should come next from which worker.

  • Internally, PyTorch iterates through the workers in order to retrieve the next worker's samples. Unfortunately, that next worker index cannot be restored in PyTorch's dataloader, so the workers are shifted internally by that offset (see megatron.energon.WorkerConfig.worker_id_offset).

  1. The dataset is wrapped in a megatron.energon.SavableDatasetWrapper. This allows the main process to communicate with the worker and send commands to the workers and retrieve the results.

  2. The state of the dataset is saved using megatron.energon.SavableDatasetWrapper.get_checkpoint(). This gives the last checkpoint from the requested sample index and stores the offset (i.e. number of samples to skip) from that checkpoint.

  3. The state is merged using megatron.energon.SavableDatasetWrapper.merge_checkpoints(). This merges the states of all workers and returns a single state that can be used to restore the state of the dataset.

  4. The state can be restored using megatron.energon.SavableDatasetWrapper.restore_state() before a worker is started, such that all workers initially receive the same state array. The worker first sets the worker index offset, then uses its (shifted) own index to get its required state from the merged state array.

can_restore_sample()[source]
Return type:

bool

cmd_queues: List[Queue]

The queues used to send commands to the workers

config()[source]

Get the configuration, which defines the dataset. Useful in conjunction with save_state and restore_state to match the configuration as well.

dataset: SavableDatasetWrapper[T] | SimpleSavableDatasetWrapper[T]

The wrapped dataset. For multiprocessing, this is a megatron.energon.SavableDatasetWrapper

id: int = 0

Class instance id

static next_id()[source]
Return type:

int

restore_sample(sample_key)[source]

Restores a sample from a key. This is useful to debug the dataset.

Parameters:

sample_key (Tuple[str | int | tuple, ...])

Return type:

T

restore_state(state)[source]

Deprecated. Use restore_state_global (or restore_state_rank) instead.

Parameters:

state (Sequence[SavableDataLoaderState | None] | None)

Return type:

None

restore_state_global(state, *, src_rank=None)[source]

Restores the saved state from save_state_global (in a torch distributed setup). The global state needs to be loaded on every rank that has a data loader instance.

Optionally, one can specify a src_rank and only provide the state once. In case of multiple data parallel groups, you must provide the state once in each data parallel group. In this case the src_rank is the rank within the data parallel group.

Parameters:
  • state (Sequence[SavableDataLoaderState | None] | None) – The state to restore, as saved by save_state_global.

  • src_rank (int | None) – The rank from which the state is broadcasted (within the data parallel group, if using DP groups).

Return type:

None

restore_state_rank(state)[source]

Restores the saved state for the current rank.

Parameters:

state (SavableDataLoaderState | None) – The state to restore, as saved by save_state_rank.

Return type:

None

result_queues: List[Queue]

The queues used to receive results from the workers

save_state(dst_rank)[source]

Deprecated. Use save_state_global (or save_state_rank) instead.

Parameters:

dst_rank (int)

Return type:

Sequence[SavableDataLoaderState | None] | None

save_state_global(global_dst_rank)[source]

Saves the state of the dataset globally, collecting the state from all ranks using torch distributed. Allows for restoring the state later using restore_state_global, given the result of this method. Typical scenario: Save the state to disk only on the dst_rank, the other ranks do not save the state. Later, restore the state either only loaded on the dst_rank or loading on all ranks separately using restore_state_global.

Note: If you want to save/restore the state per rank separately, use save_state_rank and the corresponding restore_state_rank. Also, these do not rely on torch distributed.

Parameters:

global_dst_rank (int) – The state will be gathered to this rank. The rank refers to the global rank, not the rank within the data parallel group.

Returns:

The state of the dataset (or None, if not on dst_rank).

Return type:

Sequence[SavableDataLoaderState | None] | None

save_state_rank()[source]

Saves the state of the dataset for the current rank. Allows for restoring the state later using restore_state_rank, given the result of this method.

Returns:

The state of the dataset.

Return type:

SavableDataLoaderState | None

worker_config: WorkerConfig

The worker config

Parameters:
  • dataset (SavableDatasetWrapper[T] | SimpleSavableDatasetWrapper[T])

  • checkpoint_every_sec (float)

  • checkpoint_every_min_n_samples (int | None)

  • n_checkpoints (int | None)

  • gc_collect_every_n_steps (int)

  • gc_freeze_at_start (bool)

  • prefetch_factor (int | None)

class megatron.energon.SavableDataset(worker_config)[source]

Bases: IterableDataset[T_sample], Savable, Generic[T_sample], ABC

A dataset whose state (random state, internal buffers, etc.) can be saved and restored, so that it can be resumed from a checkpoint.

How dataset state saving works:

  1. The dataset state needs to be saved in all forked worker processes which contain a copy of the main dataset instance (see megatron.energon.SavableDataLoader). Each worker returns only its own state.

  2. The main process merges the states via the megatron.energon.SavableDataset.merge_states() method in the main process on the main dataset instance (which doesn’t hold the worker states, as they were forked).

  3. The main process saves the merged state to the checkpoint.

Parameters:

worker_config (WorkerConfig)

assert_can_restore()[source]

Asserts that the dataset can restore a sample from a key.

Return type:

None

can_restore_sample()[source]

Returns True if the dataset can restore a sample from a key.

Return type:

bool

abstract config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_deep()[source]

Resets the state of the dataset to the initial state. Can only be called in a worker process.

Return type:

None

abstract reset_state_own()[source]

Resets the state of the dataset to the initial state. Can only be called in a worker process.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_sample

restore_state(state)[source]

Restores the state of the dataset. This will restore the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Parameters:

state (FlexState) – The state of the dataset as savable object. If None, restore initial state.

Return type:

None

save_state()[source]

Saves the state of the dataset. This will save and return the state of all fields in the _savable_fields tuple. Can only be called in a worker process.

Return type:

FlexState

worker_config: WorkerConfig
abstract worker_has_samples()[source]

Returns True if the worker’s split has samples. This is used to determine if this dataset yields anything.

Return type:

bool

class megatron.energon.ShuffleBufferDataset(dataset, size, *, worker_config)[source]

Bases: BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]

Shuffle buffer for the dataset.

Parameters:
  • dataset (SavableDataset[T_sample])

  • size (int)

  • worker_config (WorkerConfig)

config()[source]

Return a config dict that can be used to check if datasets have the same settings. Variables in dicts starting with “_” represent a possibly changeable setting, like a full path which may be changed.

Return type:

Dict[str, Any]

reset_state_own()[source]

Resets the state of the dataset, excl. the inner datasets.

Return type:

None

restore_sample(index)[source]

Restores a sample from its restore key. The key type is generic: it may be an integer (for a core dataset) or something more complex (e.g. for blended datasets).

By default this raises an exception (restoring is assumed non-deterministic if not implemented; determinism is not guaranteed).

Parameters:

index (Tuple[str | int | tuple, ...])

Return type:

T_sample

size: int
class megatron.energon.SimilarityInterleavedSample(__key__, __restore_key__, __subflavor__, __subflavors__, images, texts, audio=None, video=None, similarity_matrix=None, matched_text_indices=None)[source]

Bases: Sample

Sample type for interleaved media such as text with images, but without image-text alignment. That alignment has to be assigned from the similarity matrix.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • images (List[Tensor])

  • texts (List[str])

  • audio (List[Tensor] | None)

  • video (List[Tensor] | None)

  • similarity_matrix (Tensor | None)

  • matched_text_indices (List[int] | None)

audio: List[Tensor] | None

The optional audio samples of the sequence

images: List[Tensor]

The images of the sequence

matched_text_indices: List[int] | None

For each image, the index within texts of the sentence it is matched to

similarity_matrix: Tensor | None

Similarity matrix between image and text entries in the sequence

texts: List[str]

The texts of the sequence

video: List[Tensor] | None

The optional video frames of the sequence

class megatron.energon.SimilarityInterleavedWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[SimilarityInterleavedSample]

Parameters:

path (EPath)

exception megatron.energon.SkipSample[source]

Bases: Exception

Exception to raise in the map_fn to skip a sample.
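
For example, a task encoder's encode_sample may raise SkipSample to drop unwanted samples (a minimal sketch; the emptiness check is only illustrative):

from megatron.energon import SkipSample, TextSample

def encode_sample(self, sample: TextSample) -> TextSample:
    # Drop samples with empty text instead of producing an invalid encoded sample.
    if not sample.text.strip():
        raise SkipSample()
    return sample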

class megatron.energon.StandardWebdatasetFactory(path, *, sample_type, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[T_sample], Generic[T_sample]

This dataset sample loader factory uses the sample type e.g. given from a dataset.yaml, and applies the default loading logic, which includes decoding images, videos and containers.

Parameters:
  • path (EPath)

  • sample_type (Type[T_sample])

class megatron.energon.TaskEncoder[source]

Bases: ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]

Base class for task encoders.

Task encoding follows these steps:
  1. Data comes from the dataset

  2. megatron.energon.TaskEncoder.encode_sample() is called on each sample

  3. megatron.energon.TaskEncoder.batch() is called on the list of encoded samples

  4. megatron.energon.TaskEncoder.encode_batch() is called on the batch

  5. megatron.energon.TaskEncoder.to_device() is called on the encoded batch

  6. resulting encoded batch is passed to the network
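
A minimal sketch of a custom task encoder covering steps 2 and 3 (the TextBatch type and the toy character-level "tokenization" are purely illustrative, not part of the library):

from dataclasses import dataclass
from typing import List

import torch

from megatron.energon import Batch, TaskEncoder, TextSample, batch_pad_stack


@dataclass
class TextBatch(Batch):
    # Hypothetical batch dataclass; the field name is illustrative only.
    tokens: torch.Tensor


class MyTaskEncoder(TaskEncoder[TextSample, torch.Tensor, TextBatch, TextBatch]):
    def encode_sample(self, sample: TextSample) -> torch.Tensor:
        # Step 2: encode a single sample (one integer per character as a stand-in tokenizer).
        return torch.tensor([ord(c) for c in sample.text], dtype=torch.long)

    def batch(self, samples: List[torch.Tensor]) -> TextBatch:
        # Step 3: collate variable-length token tensors by zero-padding and stacking.
        return TextBatch(tokens=batch_pad_stack(samples))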

batch(samples)[source]

Collate a list of encoded samples into a batch. May raise megatron.energon.SkipSample to skip a batch.

Parameters:

samples (List[T_encoded_sample])

Return type:

T_raw_batch

batch_group_criterion(sample)[source]

Return a group criterion for the sample, i.e. the key of the bucket to put this sample into and the size of that bucket (= batch size). The bucket size must always be the same for the same bucket key. The default implementation does not group: it returns (None, None), so only a single group is used.

May raise megatron.energon.SkipSample to skip a sample.

Parameters:

sample (T_encoded_sample)

Return type:

Tuple[Hashable, int | None]
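
For instance, a task encoder might bucket samples by image orientation so that each batch contains only similarly shaped samples (a hedged sketch; the orientation criterion and bucket sizes are illustrative):

from typing import Hashable, Optional, Tuple

from megatron.energon import TaskEncoder, VQASample


class BucketedTaskEncoder(TaskEncoder[VQASample, VQASample, list, list]):
    def batch_group_criterion(self, sample: VQASample) -> Tuple[Hashable, Optional[int]]:
        # image has shape (C, H, W); group landscape and portrait images separately,
        # each bucket with its own fixed batch size.
        _, h, w = sample.image.shape
        if w >= h:
            return "landscape", 8
        return "portrait", 16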

build_batch(dataset, *, batch_size, batch_drop_last=False, packing_buffer_size=None, worker_config)[source]

Applies the batcher to the dataset.

Parameters:
  • dataset (SavableDataset[T_encoded_sample])

  • batch_size (int | None)

  • batch_drop_last (bool)

  • packing_buffer_size (int | None)

  • worker_config (WorkerConfig)

Return type:

SavableDataset[T_raw_batch]

build_cook_crude_sample(dataset, *, worker_config)[source]

Applies the sample cooker to the dataset if we have cookers registered.

Parameters:
  • dataset (SavableDataset)

  • worker_config (WorkerConfig)

Return type:

SavableDataset[T_sample]

build_encode_sample(dataset, *, worker_config)[source]

Applies the sample encoder to the dataset.

Parameters:
  • dataset (SavableDataset[T_sample])

  • worker_config (WorkerConfig)

Return type:

SavableDataset[T_encoded_sample]

build_train_datasets(*, datasets, worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, virtual_epoch_length=0, shuffle_buffer_size=None, blend_mode=DatasetBlendMode.NONE, repeat=True)[source]

Combines train datasets to a single dataset.

Parameters:
  • datasets (List[Tuple[BaseCoreDatasetFactory[T_sample], float | int | None]])

  • worker_config (WorkerConfig)

  • batch_size (int | None)

  • batch_drop_last (bool)

  • packing_buffer_size (int | None)

  • virtual_epoch_length (int)

  • shuffle_buffer_size (int | None)

  • blend_mode (DatasetBlendMode)

  • repeat (bool)

Return type:

SavableDataset[T_batch]

build_val_datasets(*, datasets, worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, limit=None)[source]

Combines val datasets to a single dataset.

Parameters:
  • datasets (List[BaseCoreDatasetFactory[T_sample]])

  • worker_config (WorkerConfig)

  • batch_size (int)

  • batch_drop_last (bool)

  • packing_buffer_size (int | None)

  • limit (int | None)

Return type:

SavableDataset[T_batch]

cook_crude_sample(sample)[source]

Transforms a crude sample into a proper sample using a matching registered cooker. Samples that are not crude are passed through unchanged.

Parameters:

sample (T_sample | CrudeSample)

Return type:

T_sample

cookers: Sequence[Cooker] = ()
property current_batch_index: int

Returns the current index for the next batch yielded from the current worker. Each batch on the current rank gets a strictly increasing unique number. Counting happens separately on each rank (i.e. all ranks will see the same numbers for the same batch index).

property current_sample_index: int

Returns the current index for the next sample yielded from the current routine (e.g. for encode_sample, batch, or encode_batch). Each routine will get a number representing the number of calls to that function. Across workers, this number will be unique, but it is not synced across workers, so it may increase at different intervals (e.g. if batching does not proceed the same way for all batches). When restoring a sample, this number is also restored and can be relied on for deterministic reproduction of a sample's randomness.

encode_batch(batch)[source]

Encode a batch of samples. May raise megatron.energon.SkipSample to skip a batch. Alternatively, this can be a generator that yields (or ignores) new batches.

Parameters:

batch (T_raw_batch)

Return type:

T_batch | Generator[T_batch, None, None]

encode_sample(sample)[source]

Encode a single sample. May raise megatron.energon.SkipSample to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples.

Parameters:

sample (T_sample)

Return type:

T_encoded_sample | Generator[T_encoded_sample, None, None]
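
Because encode_sample may be a generator, one input sample can yield several encoded samples. A hedged sketch (the horizontal-flip augmentation and the derived key are only illustrative):

import dataclasses
from typing import Generator

import torch

from megatron.energon import TaskEncoder, VQASample


class AugmentingTaskEncoder(TaskEncoder[VQASample, VQASample, list, list]):
    def encode_sample(self, sample: VQASample) -> Generator[VQASample, None, None]:
        # Yield the original sample ...
        yield sample
        # ... and a horizontally flipped copy with a derived key.
        yield dataclasses.replace(
            sample,
            __key__=sample.__key__ + ".flip",
            image=torch.flip(sample.image, dims=[-1]),
        )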

pack_selected_samples(samples)[source]

Given one set of samples to pack, returns the final packed sample. Packing is only active when packing_buffer_size is set. Internally this stage is called “final_packing”.

Parameters:

samples (List[T_sample]) – The samples to pack into a single sample

Return type:

T_sample

Returns: The final packed sample.

select_samples_to_pack(samples)[source]

For packing, selects the samples to be packed together. Packing is only active when packing_buffer_size is set. Internally this stage is called “pre_packing”.

Parameters:

samples (List[T_sample]) – The samples to pre-pack. A full buffer will be passed into the function.

Return type:

List[List[T_sample]]

Returns: The pre-packed samples as a list of lists of samples.
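
Taken together, a rough sketch of a packing task encoder (the character budget and the text concatenation are illustrative; packing only becomes active when packing_buffer_size is set on the dataset builders):

import dataclasses
from typing import List

from megatron.energon import TaskEncoder, TextSample


class PackingTaskEncoder(TaskEncoder[TextSample, TextSample, list, list]):
    def select_samples_to_pack(self, samples: List[TextSample]) -> List[List[TextSample]]:
        # Greedily group the buffered samples so each group stays below a character budget.
        groups: List[List[TextSample]] = []
        current: List[TextSample] = []
        current_len = 0
        for sample in samples:
            if current and current_len + len(sample.text) > 1024:
                groups.append(current)
                current, current_len = [], 0
            current.append(sample)
            current_len += len(sample.text)
        if current:
            groups.append(current)
        return groups

    def pack_selected_samples(self, samples: List[TextSample]) -> TextSample:
        # Merge one selected group into a single sample by concatenating the texts.
        return dataclasses.replace(samples[0], text="\n".join(s.text for s in samples))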

class megatron.energon.TextSample(__key__, __restore_key__, __subflavor__, __subflavors__, text)[source]

Bases: Sample

Sample type for simple text.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • text (str)

text: str

The text of the sample

class megatron.energon.TextWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[TextSample]

Parameters:

path (EPath)

class megatron.energon.VQAOCRWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[VQAOCRSample]

Parameters:

path (EPath)

class megatron.energon.VQASample(__key__, __restore_key__, __subflavor__, __subflavors__, image, context, answers=None, answer_weights=None)[source]

Bases: Sample

Sample type for visual question answering.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • image (Tensor)

  • context (str)

  • answers (List[str] | None)

  • answer_weights (Tensor | None)

answer_weights: Tensor | None

The weights of the possible answers. Optionally available.

answers: List[str] | None

The possible answers. Not set for testing.

context: str

The context/question for the image

image: Tensor

The input image tensor in the shape (C, H, W)

class megatron.energon.VQAWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[VQASample]

Parameters:

path (EPath)

class megatron.energon.VidQASample(__key__, __restore_key__, __subflavor__, __subflavors__, video, context, answers=None, answer_weights=None)[source]

Bases: Sample

Sample type for video question answering.

Parameters:
  • __key__ (str)

  • __restore_key__ (Tuple[str | int | tuple, ...])

  • __subflavor__ (str | None)

  • __subflavors__ (Dict[str, Any] | None)

  • video (VideoData)

  • context (str)

  • answers (List[str] | None)

  • answer_weights (Tensor | None)

answer_weights: Tensor | None

The weights of the possible answers. Optionally available.

answers: List[str] | None

The possible answers. Not set for testing.

context: str

The context/question for the image.

video: VideoData

The video data containing the image and audio info.

class megatron.energon.VidQAWebdataset(path, **kwargs)[source]

Bases: DefaultDecoderWebdatasetFactory[VidQASample]

Parameters:

path (EPath)

class megatron.energon.WorkerConfig(rank, world_size, num_workers, data_parallel_group=None, seed_offset=0, worker_debug_path=None, worker_log_level=0, _worker_debug_file=None, _worker_debug_file_worker_id=None)[source]

Bases: object

Provides information about the current worker and the global configuration, giving each data parallel rank its proper config. Every rank from 0 up to world_size-1 must be used. If configured incorrectly, the datasets might yield duplicate or missing data, because data is split over the data parallel ranks according to this config! You may assign the same rank to multiple processes if they need to retrieve the same data.

Parameters:
  • rank (int)

  • world_size (int)

  • num_workers (int)

  • data_parallel_group (ProcessGroup | None)

  • seed_offset (int)

  • worker_debug_path (str | None)

  • worker_log_level (int)

  • _worker_debug_file (TextIO | None)

  • _worker_debug_file_worker_id (int | None)

property active_worker_batch_index: int

Returns the current batch index for the actively iterating worker.

active_worker_config: ClassVar[WorkerConfig | None] = None

The current worker config within the current iterating worker

property active_worker_sample_index: int

Returns the current sample index for the actively iterating worker.

assert_worker()[source]

Checks if the current process is a worker (if configured so), and that the workers are properly configured.

config()[source]
Return type:

Dict[str, Any]

data_parallel_group: ProcessGroup | None

If not using all ranks for data parallel, set this to the corresponding group.

static default_worker_config(num_workers=4, data_parallel_group=None)[source]

Returns the default worker config using torch distributed if available. If torch distributed is not available, a single local rank is assumed.

Parameters:
  • num_workers (int)

  • data_parallel_group (ProcessGroup | None)

Return type:

WorkerConfig
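
For example (values are illustrative):

from megatron.energon import WorkerConfig

# Explicit config for data parallel rank 0 of 2, with 4 worker processes per rank.
worker_config = WorkerConfig(rank=0, world_size=2, num_workers=4)

# Or derive rank and world_size from torch.distributed, if it is initialized.
worker_config = WorkerConfig.default_worker_config(num_workers=4)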

global_rank()[source]

Returns the global rank for this worker config, i.e. the rank across all processes rather than the rank within the data parallel group.

Return type:

int

global_worker_id(override_local_worker_id=None)[source]

Returns the global worker index, computed from the data parallel rank and the number of workers per rank. Alternatively, you can override the local worker id.

Parameters:

override_local_worker_id (int, optional) – The local worker id to override. None means the current worker, which is the default.

Return type:

int

num_workers: int

The number of workers per rank. May be 0 to disable worker processes.

rank: int

The data parallel rank/id of the current process.

rank_worker_id()[source]

Returns the self worker id within the current rank.

Return type:

int

seed_offset: int
should_log(level)[source]
Parameters:

level (int)

Return type:

bool

worker_activate(sample_index, override_global_rank=None)[source]

Activates the worker config for the current worker and sets it as actively iterating. Must be called before calling next() on the datasets.

Parameters:
  • sample_index (int)

  • override_global_rank (int | None)

worker_deactivate()[source]

Deactivates the worker config for the current worker and deactivates it for iterating. Must be called after calling next() on the datasets.

worker_debug_path: str | None
worker_id_offset: ClassVar[int] = 0
worker_log(data)[source]

Logs the given data to the worker debug file.

Parameters:

data (dict)

Return type:

None

worker_log_level: int

Log level for worker logging.

worker_pop_sample_index()[source]

Pops the last sample index from the sample index stack. Counterpart to worker_push_sample_index; should be called by wrapping datasets after calling into inner datasets.

worker_push_sample_index(sample_index)[source]

Pushes a new sample index onto the sample index stack. Should be called by wrapping datasets before calling into inner datasets.

Parameters:

sample_index (int)

worker_seed(override_local_worker_id=None)[source]

Returns the seed for the current worker (or a specified worker). Based on the current worker id and the seed offset, a seed is computed. Alternatively, you can override the local worker id with a fixed one to pregenerate seeds for multiple workers.

Parameters:

override_local_worker_id (int, optional) – The local worker id to override. None means the current worker, which is the default.

Return type:

int
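
A small sketch of using worker_seed to drive per-worker randomness (the torch.Generator usage is just one option):

import torch

from megatron.energon import WorkerConfig

# Inside code that runs in an iterating worker (e.g. a task encoder method):
cfg = WorkerConfig.active_worker_config
if cfg is not None:
    rng = torch.Generator()
    rng.manual_seed(cfg.worker_seed())
    noise = torch.rand(4, generator=rng)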

world_size: int

The total number of data parallel processes.

megatron.energon.basic_sample_keys(crude_sample)[source]

A convenience helper to extract the basic keys from a crude sample, which you will always need to forward to the cooked sample.

Parameters:

crude_sample (dict)

Return type:

dict

megatron.energon.batch_list(batch)[source]

Return the batch as a plain list, without stacking or padding.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.batch_pad_stack(batch)[source]

Stack a batch of arbitrary-sized tensors padded with 0s.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.batch_stack(batch)[source]

Stack a batch of tensors.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.concat_pad(batch)[source]

Concat a batch of arbitrary-sized tensors padded with 0s.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.generic_batch(batch)[source]

Based on the types/shapes of the batch: Will either pad and stack, or return as list. Recurses structures (dict, dataclass, namedtuple) and applies the same logic to each field.

Parameters:

batch (List[Any])

Return type:

Any

megatron.energon.generic_concat(batch)[source]

Based on the types/shapes of the batch: Will either pad and concatenate, or return as list. Recurses structures (dict, dataclass, namedtuple) and applies the same logic to each field.

Parameters:

batch (List[Any])

Return type:

Any
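
A quick comparison of these collation helpers (shapes are illustrative):

import torch

from megatron.energon import batch_pad_stack, batch_stack, generic_batch

a = torch.ones(3, 4)
b = torch.ones(2, 4)

batch_stack([a, a]).shape      # torch.Size([2, 3, 4]); requires equally sized tensors
batch_pad_stack([a, b]).shape  # torch.Size([2, 3, 4]); b is zero-padded to 3x4
generic_batch([a, b])          # chooses a strategy based on the types/shapes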

megatron.energon.get_loader(dataset, *, worker_config=None, prefetch_factor=2)[source]

Get a dataloader for the given dataset.

Parameters:
  • dataset (SavableDataset[T]) – The dataset to create a loader for.

  • worker_config (WorkerConfig | None) – Deprecated. Please pass this to the dataset instead.

  • prefetch_factor (int)

Returns:

The instantiated torch.utils.data.DataLoader, yielding batches from the dataset.

Return type:

BasicDataLoader[T]

megatron.energon.get_savable_loader(dataset, *, worker_config=None, checkpoint_every_sec=60, checkpoint_every_min_n_samples=None, n_checkpoints=None, gc_collect_every_n_steps=10, prefetch_factor=2)[source]

Get a dataloader for the given dataset.

Parameters:
  • dataset (SavableDataset[T]) – The dataset to create a loader for.

  • worker_config (WorkerConfig | None) – Deprecated. Please pass this to the dataset instead.

  • checkpoint_every_sec (float) – The time in seconds after which an internal checkpoint is saved. Restoring a checkpoint may take about the same duration, and checkpointing introduces additional overhead while reading data from the dataset, so this value should be chosen accordingly. Only applies if using workers.

  • checkpoint_every_min_n_samples (int | None) – Overwrites the minimum number of samples between checkpoints. Defaults to number of workers * 2. Only applies if using workers.

  • n_checkpoints (int | None) – The number of internal checkpoints to keep. Only applies if using workers. If None, computes a suitable value.

  • gc_collect_every_n_steps (int)

  • prefetch_factor (int)

Returns:

The instantiated megatron.energon.SavableDataLoader, yielding batches from the dataset, allowing to save the state of the dataset.

Return type:

SavableDataLoader[T]

megatron.energon.get_train_dataset(path, *, split_part='train', worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, shuffle_buffer_size, max_samples_per_sequence, virtual_epoch_length=0, shuffle_over_epochs_multiplier=1, task_encoder=<megatron.energon.task_encoder.base.DefaultTaskEncoder object>, repeat=True, **kwargs)[source]

Get the training dataset with sensible defaults. See get_dataset for more details.

The following recipe will be used:
  • megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset)

  • task_encoder.encode_sample

  • (megatron.energon.MixDataset if mixing)

  • megatron.energon.BatchDataset with task_encoder.batch for collation

  • task_encoder.encode_batch

  • megatron.energon.EpochizeDataset (if virtual_epoch_length is set)

Parameters:
  • path (str | EPath | Path) – Path to the dataset.

  • split_part (Literal['train'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • batch_size (int | None) – Size of a batch. If None, do not batch

  • batch_drop_last (bool) – If true, drop the last batch if it is smaller than batch_size.

  • shuffle_buffer_size (int | None) – Size of the sample shuffle buffer (before task encoding).

  • max_samples_per_sequence (int | None) – If set, limit the number of samples per sample-sequence to this.

  • virtual_epoch_length (int) – If set, the dataset will be epochized to this length: iteration is suspended and the for-loop returns after this many samples, and the next for-loop continues iterating where the previous one stopped. Otherwise, the dataset will loop indefinitely.

  • shuffle_over_epochs_multiplier (int | None) – Shuffle the shards over this many epochs.

  • task_encoder (TaskEncoder[Any, Any, Any, T]) – Task encoder to use.

  • repeat (bool) – By default, the inner datasets will loop. If set to False, stop iteration after one epoch. Must only be set to False in conjunction with blend_epochized in the metadataset if one is used.

  • **kwargs – Additional arguments to the dataset constructor.

  • packing_buffer_size (int | None)

Returns:

The dataset, ready to be wrapped in a loader via get_loader or get_savable_loader.

Return type:

SavableDataset[T]
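
Putting it together, a typical training setup might look like the following sketch (the path and the numeric values are placeholders; the default task encoder is used since none is passed):

from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset

worker_config = WorkerConfig.default_worker_config(num_workers=4)

train_ds = get_train_dataset(
    "/path/to/dataset",  # placeholder path
    worker_config=worker_config,
    batch_size=8,
    shuffle_buffer_size=100,
    max_samples_per_sequence=100,
)

loader = get_savable_loader(train_ds)
for batch in loader:
    ...  # training step on `batch`
    break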

megatron.energon.get_val_dataset(path, *, split_part='val', worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, limit=None, task_encoder=<megatron.energon.task_encoder.base.DefaultTaskEncoder object>, **kwargs)[source]

Get the validation/test dataset with sensible defaults. See get_dataset for more details.

The following recipe will be used:
  • megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset)

  • task_encoder.encode_sample

  • (megatron.energon.MixDataset if mixing)

  • megatron.energon.BatchDataset with task_encoder.batch for collation

  • megatron.energon.LimitDataset (if limit is set)

  • task_encoder.encode_batch

Parameters:
  • path (str | EPath | Path) – Path to the dataset.

  • split_part (Literal['val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • batch_size (int) – Size of a batch

  • batch_drop_last (bool) – If true, drop the last batch if it is smaller than batch_size.

  • limit (int | None) – If set, limit the number of batches loaded from the dataset to this.

  • task_encoder (TaskEncoder[Any, Any, Any, T]) – Task encoder to use.

  • **kwargs – Additional arguments to the dataset constructor.

  • packing_buffer_size (int | None)

Returns:

The loaded dataset.

Return type:

SavableDataset[T]
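
A matching validation setup (sketch, reusing the placeholder path and worker config from the training example above):

from megatron.energon import get_loader, get_val_dataset

val_ds = get_val_dataset(
    "/path/to/dataset",  # placeholder path
    worker_config=worker_config,
    batch_size=8,
    limit=100,
)
val_loader = get_loader(val_ds)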

megatron.energon.get_val_datasets(path, *, split_part='val', worker_config, batch_size, batch_drop_last=False, packing_buffer_size=None, limit=None, task_encoder=<megatron.energon.task_encoder.base.DefaultTaskEncoder object>, **kwargs)[source]

Get the validation/test datasets with sensible defaults. See get_dataset for more details.

The following recipe will be used:
  • megatron.energon.dataset_config.get_dataset_from_config() (loads the raw dataset)

  • task_encoder.encode_sample

  • (megatron.energon.MixDataset if mixing)

  • megatron.energon.BatchDataset with task_encoder.batch for collation

  • megatron.energon.LimitDataset (if limit is set)

  • task_encoder.encode_batch

Parameters:
  • path (str | EPath | Path) – Path to the dataset.

  • split_part (Literal['val', 'test'] | str) – Default split part to use.

  • worker_config (WorkerConfig) – Worker configuration to use.

  • batch_size (int) – Size of a batch

  • batch_drop_last (bool) – If true, drop the last batch if it is smaller than batch_size.

  • limit (int | None) – If set, limit the number of batches loaded from the dataset to this.

  • task_encoder (TaskEncoder[Any, Any, Any, T]) – Task encoder to use.

  • **kwargs – Additional arguments to the dataset constructor.

  • packing_buffer_size (int | None)

Returns:

The loaded val datasets, with the source datasets.

Return type:

List[Tuple[SavableDataset[T], BaseCoreDatasetFactory]]

megatron.energon.homogeneous_concat_mix(samples)[source]

Mixes a list of batches into a single batch. The default implementation is to concat the batches if they are all of the same type, otherwise return a list of batches.

Parameters:

samples (List[T_batch_in]) – The samples to mix.

Returns:

The mixed batch.

Return type:

T_batch

megatron.energon.load_dataset(path, **kwargs)[source]

Loads a (meta)dataset.

Parameters:

path (str | EPath | Path)

Return type:

DatasetLoaderInterface

megatron.energon.prepare_metadataset(path)[source]
Parameters:

path (EPath)

megatron.energon.stateless(fn=None, *, restore_seeds=False)[source]

Decorator to mark a function of the task encoder as restorable.

Parameters:
  • fn (Callable[[...], T_sample] | None) – The function to decorate.

  • restore_seeds (bool) – Whether to restore the seeds for the function. I.e. the seeds are set from the sample index and the worker seed, such that they can be restored when a sample is restored from that function.

Return type:

Callable[[…], T_sample]

Usage:

@stateless
def encode_sample(self, sample: T_sample) -> T_encoded_sample:
    ...

# Or if randomness is used (e.g. for augmentations):
@stateless(restore_seeds=True)
def encode_sample(self, sample: T_sample) -> T_encoded_sample:
    ...

class megatron.energon.task_encoder.cooking.Cooker(cook, is_subflavor=None, has_subflavors=None, condition=None)[source]

Bases: object

A cooker transforms a crude sample (simple dict) into a specific sample type inheriting from Sample.

The cook method performs the transformation; the other fields are used to select the samples that this cooker can transform. If no filters are provided, the cooker will transform any sample.

Parameters:
  • cook (Callable[[dict], Sample])

  • is_subflavor (str | None)

  • has_subflavors (dict | None)

  • condition (Callable[[dict], bool] | None)

condition: Callable[[dict], bool] | None = None
cook: Callable[[dict], Sample]
has_subflavors: dict | None = None
is_match(crude_sample)[source]
Parameters:

crude_sample (CrudeSample)

Return type:

bool

is_subflavor: str | None = None
megatron.energon.task_encoder.cooking.basic_sample_keys(crude_sample)[source]

A convenience helper to extract the basic keys from a crude sample, which you will always need to forward to the cooked sample.

Parameters:

crude_sample (dict)

Return type:

dict
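
As an illustration, a cooker turning crude samples with a raw "txt" entry into TextSample instances might be registered on a task encoder as follows (the "txt" key and the subflavor filter are hypothetical; whether the raw value arrives as bytes depends on the dataset):

from megatron.energon import TaskEncoder, TextSample, basic_sample_keys
from megatron.energon.task_encoder.cooking import Cooker


def cook_text(sample: dict) -> TextSample:
    raw = sample["txt"]
    text = raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
    # basic_sample_keys forwards the mandatory sample keys (e.g. __key__, __restore_key__).
    return TextSample(**basic_sample_keys(sample), text=text)


class CookingTaskEncoder(TaskEncoder[TextSample, TextSample, list, list]):
    cookers = (
        Cooker(cook_text, has_subflavors={"kind": "text"}),
    )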