# Source code for megatron.energon.cache.base

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, TypeVar

from megatron.energon.edataclass import edataclass
from megatron.energon.source_info import SourceInfo, add_source_info

T = TypeVar("T")


class FileStore(Generic[T]):
    """Randomly accessible dataset, indexed by sample key.

    Subclasses provide ``__getitem__`` (returning the data together with its
    source info) and ``get_path``.

    NOTE(review): this class does not derive from ``ABC``, so the
    ``@abstractmethod`` markers below are not enforced when instantiating.
    """

    @abstractmethod
    def __getitem__(self, key: str) -> tuple[T, SourceInfo]:
        """Look up ``key`` and return a ``(data, source_info)`` pair."""
        ...

    @abstractmethod
    def get_path(self) -> str:
        """Return the path of the underlying dataset."""
        ...

    def get(self, key: str, sample: Any = None) -> Any:
        """Fetch the data stored under ``key``.

        The source info returned alongside the data is attached to ``sample``
        via ``add_source_info`` before the data is handed back.

        Args:
            key: Sample key to look up.
            sample: Object that receives the source info (default ``None``).

        Returns:
            The data part of ``self[key]``.
        """
        payload, info = self[key]
        add_source_info(sample, info)
        return payload
@edataclass
class Lazy(Generic[T]):
    """Abstract base class for lazy references to data.

    A lazy reference names a file (``fname``) inside a data source (``ds``)
    together with an associated cache pool (``pool``). The data itself is only
    produced when ``get`` is called.

    Two references compare (and hash) equal when they refer to the very same
    ``ds`` object and have equal ``fname``; ``pool`` is not compared.

    NOTE(review): this class does not derive from ``ABC``, so ``get`` being
    marked ``@abstractmethod`` is not enforced when instantiating.
    """

    # Data source the file belongs to.
    ds: FileStore
    # Name/key of the file within ``ds``.
    fname: str
    # Cache pool associated with this reference.
    pool: "CachePool"

    @abstractmethod
    def get(self, sample: Any = None) -> T:
        """Get the lazy data now and add its source info to ``sample``.

        Args:
            sample: Object that receives the source info of the data.

        Returns:
            The referenced data.
        """
        ...

    def __hash__(self) -> int:
        """Allow usage in sets and as dict keys."""
        # Identity-based hash of ``ds`` matches the ``is`` comparison in __eq__.
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Compare by data-source identity and file name.

        For objects that are not ``Lazy`` instances, ``NotImplemented`` is
        returned so Python can try the reflected comparison before falling
        back to identity (which yields ``False`` for distinct objects).
        """
        if not isinstance(other, Lazy):
            return NotImplemented
        return self.ds is other.ds and self.fname == other.fname
@edataclass
class MockLazy(Lazy[T]):
    """Stand-in for a ``Lazy`` whose data comes from a user-supplied function.

    Instead of reading from a data source, ``get`` calls ``get_fn(fname)``.
    ``ds`` and ``pool`` are set to ``None`` on construction.
    """

    ds: FileStore
    fname: str
    pool: "CachePool"
    # Callable producing the data; it receives ``fname``.
    get_fn: Callable[[str], T]

    def __init__(self, fname: str, get_fn: Callable[[str], T]):
        """Create the mock reference.

        Args:
            fname: File name stored on the object (also passed to ``get_fn``).
            get_fn: Function that retrieves or generates the data.
        """
        # No backing data source and no cache pool.
        self.ds = None
        self.fname, self.pool = fname, None
        self.get_fn = get_fn

    def get(self, sample: Any = None) -> T:
        """Produce the data by calling ``get_fn(fname)``.

        ``sample`` is accepted for interface compatibility but is left
        untouched: no source info is attached.
        """
        producer = self.get_fn
        return producer(self.fname)

    def __hash__(self) -> int:
        """Hash over ``(fname, get_fn)``, consistent with ``__eq__``."""
        identity = (self.fname, self.get_fn)
        return hash(identity)

    def __eq__(self, other: Any) -> bool:
        """Equal to another ``MockLazy`` with equal ``fname`` and ``get_fn``."""
        if isinstance(other, MockLazy):
            return self.fname == other.fname and self.get_fn == other.get_fn
        return False

    def __repr__(self) -> str:
        return f"MockLazy(fname={self.fname!r}, get_fn={self.get_fn!r})"
class CachePool(ABC):
    """Loads needed data in the background so it can be accessed later.

    Usage: obtain a ``Lazy`` handle with ``get_lazy``, then call
    ``Lazy.get()`` once the data is actually needed. The most important
    example is ``FileStoreCachePool``, which caches data on a local SSD disk.
    """

    @abstractmethod
    def get_lazy(self, ds: FileStore, fname: str) -> Lazy:
        """Return a lazy reference to the data of ``fname`` in ``ds``."""
        ...

    @abstractmethod
    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """Return the data of ``fname`` in ``ds`` immediately.

        The data's source info is added to ``sample``.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """Shut down the cache pool."""
        ...
class FileStoreDecoder(ABC):
    """Abstract base class for decoders that turn raw file bytes into data."""

    @abstractmethod
    def decode(self, fname: str, data: bytes) -> Any:
        """Decode the specified file (i.e. ``path/key.ext``).

        The extension of ``fname`` is used to select the decoder.

        Args:
            fname: The file name of the file to decode.
            data: The raw bytes of the file to decode.

        Returns:
            The decoded field's data.
        """
        ...