Source code for megatron.energon.cache.file_store

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Optional, Union

from megatron.energon.cache.base import FileStore, FileStoreDecoder
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.itar_reader import SqliteITarEntryReader
from megatron.energon.source_info import SourceInfo


class DecodeFileStore(FileStore[Any]):
    """FileStore wrapper that runs a decoder over every item read through it."""

    def __init__(
        self,
        inner_reader: FileStore[bytes],
        *,
        decoder: FileStoreDecoder,
    ):
        """
        Args:
            inner_reader: The underlying FileStore that supplies the raw data.
            decoder: The decoder applied to each item fetched from ``inner_reader``.
        """
        self.decoder = decoder
        self.inner_reader = inner_reader

    def __getitem__(self, fname: str) -> tuple[Any, SourceInfo]:
        """Look up ``fname`` in the wrapped store and return it decoded.

        Args:
            fname: The key to look up; it is also handed to the decoder.

        Returns:
            The decoder's result paired with the source info reported by the
            wrapped store.
        """
        raw, info = self.inner_reader[fname]
        decoded = self.decoder.decode(fname, raw)
        return decoded, info

    def get_path(self) -> str:
        """Returns whatever path the wrapped store reports."""
        wrapped = self.inner_reader
        return wrapped.get_path()

    def __str__(self):
        # Embeds the string forms of both the wrapped store and the decoder.
        return f"DecodeFileStore(inner_reader={self.inner_reader}, decoder={self.decoder})"
class SystemFileStore(FileStore[bytes]):
    """FileStore backed directly by files on the file system."""

    def __init__(self, base_dir: Optional[Union[EPath, str]] = None):
        """
        Args:
            base_dir: Directory that relative keys are resolved against. When it is
                None, only absolute paths should be passed to __getitem__.
        """
        if base_dir is None:
            self.base_dir = None
        else:
            self.base_dir = EPath(base_dir)

    def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]:
        """Read the file addressed by ``key`` and return its raw bytes.

        Args:
            key: Path of the file; joined onto ``base_dir`` when one was configured,
                otherwise used as the path itself.

        Returns:
            The file contents together with a SourceInfo whose ``dataset_path`` is
            ``base_dir`` and whose ``file_names`` holds only ``key``.
        """
        # Without a base directory the key alone is the path to open.
        target = EPath(key) if self.base_dir is None else self.base_dir / key

        # Pull the whole file into memory as bytes.
        with target.open("rb") as stream:
            payload = stream.read()

        info = SourceInfo(
            dataset_path=self.base_dir,
            index=None,
            shard_name=None,
            file_names=(key,),
        )
        return payload, info

    def get_path(self) -> str:
        """Returns the base directory rendered as a string."""
        base = self.base_dir
        return str(base)

    def __str__(self):
        # Only the base directory identifies this store.
        return f"SystemFileStore(base_dir={self.base_dir})"
class WebdatasetFileStore(SqliteITarEntryReader, FileStore[bytes]):
    """FileStore that reads files directly out of the tar files of a prepared energon dataset."""

    def __init__(
        self,
        dataset_path: EPath,
    ):
        """
        Args:
            dataset_path: Path of the prepared dataset; forwarded to the parent
                initializer as ``base_path``.
        """
        # NOTE(review): key_is_full_entryname presumably makes lookups use the full
        # tar entry name as the key -- confirm against SqliteITarEntryReader.
        super().__init__(
            key_is_full_entryname=True,
            base_path=dataset_path,
        )

    def get_path(self) -> str:
        """Returns ``base_path`` rendered as a string."""
        # base_path is not assigned in this class; presumably the parent
        # initializer stores it -- verify against SqliteITarEntryReader.
        base = self.base_path
        return str(base)