Source code for megatron.energon.flavors.webdataset.default_decoder_webdataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Optional, TypeVar

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.default_generic_webdataset import (
    DefaultGenericWebdatasetFactory,
)
from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder
from megatron.energon.flavors.webdataset.structs import FilteredSample

T_sample = TypeVar("T_sample", covariant=True)


class DefaultDecoderWebdatasetFactory(DefaultGenericWebdatasetFactory[T_sample], Generic[T_sample]):
    """
    Webdataset factory that optionally runs a sample decoder before the
    parent's sample loading.

    Builds on the default generic webdataset loading by decoding the files
    contained in a sample (such as images, videos or nested containers).
    """

    # Decoder applied to each sample before loading; ``None`` disables decoding.
    _decoder: Optional[SampleDecoder]

    def __init__(
        self,
        path: EPath,
        *,
        decoder: Optional[SampleDecoder] = SampleDecoder(),
        **kwargs,
    ):
        """
        Create the webdataset sample loader factory together with its decoder.

        Args:
            path: Dataset path, forwarded to the parent constructor.
            decoder: Decoder to apply to every sample. Pass ``None`` to skip
                decoding and keep the raw bytes.
            **kwargs: Remaining keyword arguments, forwarded to the parent
                constructor.
        """
        # The decoder is stored before the parent constructor runs, so it is
        # already set should the parent use it during initialization.
        self._decoder = decoder
        super().__init__(path, **kwargs)

    def load_sample(self, sample: FilteredSample) -> T_sample:
        """
        Decode the sample (if a decoder is configured) and delegate to the
        parent's ``load_sample``.

        Args:
            sample: The filtered raw sample to load.

        Returns:
            Whatever the parent's ``load_sample`` returns for the (decoded)
            sample.
        """
        decoder = self._decoder
        if decoder is None:
            # No decoder configured: hand the sample to the parent unchanged.
            return super().load_sample(sample)
        return super().load_sample(decoder(sample))

    def config(self) -> Dict[str, Any]:
        """
        Return the parent's configuration combined with the decoder's
        configuration (when a decoder is set).

        Returns:
            A new dict holding the parent's config entries followed by the
            decoder's config entries.
        """
        parent_config = super().config()
        decoder_config = {} if self._decoder is None else self._decoder.config()
        # Keyword-unpacking into dict() (rather than a dict-display merge) is
        # kept on purpose: a key present in both configs raises a TypeError
        # instead of being silently overwritten.
        return dict(**parent_config, **decoder_config)