Source code for megatron.energon.cache.file_cache_pool

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import hashlib
import os
import pickle
import random
import string
import tempfile
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, TypeVar

from megatron.energon.cache.base import CachePool, FileStore, Lazy
from megatron.energon.cache.file_store import DecodeFileStore
from megatron.energon.edataclass import edataclass
from megatron.energon.fork_hook import ForkMixin
from megatron.energon.source_info import SourceInfo, add_source_info

T = TypeVar("T")


@edataclass
class FileCacheLazy(Lazy[T]):
    """
    Represents a reference to a background prefetch.
    """

    # The cache pool that we're using.
    pool: "FileStoreCachePool"
    # The entry in the cache pool that we're using.
    entry: "_PendingTask"
    # If get() was called, this will be the data (uncached).
    _data: Optional[tuple[T, SourceInfo]] = None
    def get(self, sample: Any = None) -> T:
        """
        Returns the data and adds the source info to the sample.

        If the background job hasn't started, we cancel it, do a direct read, and remove
        ourselves from the pool's references. Otherwise, we wait for the job to finish,
        read from cache, and remove ourselves.
        """
        if self._data is None:
            self._data = self.pool._get_data(self.ds, self.fname, self.entry)
        assert self._data is not None
        add_source_info(sample, self._data[1])
        return self._data[0]
    def __hash__(self) -> int:
        """Allows usage in sets and dicts as key."""
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Allows usage in sets and dicts as key. Compares the data source and the filename."""
        if not isinstance(other, Lazy):
            return False
        return self.ds is other.ds and self.fname == other.fname

    def __del__(self):
        if self._data is None:
            with self.pool._lock:
                # Data was never fetched; still decrement the refcount to delete the cache entry
                self.pool._decrement_refcount_and_cleanup((self.ds.get_path(), self.fname))
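# Illustrative sketch (not part of the module): resolving a lazy handle. `pool`,
# `store`, and `sample` are hypothetical placeholders.
#
#     lazy = pool.get_lazy(store, "shards/000042.bin")  # schedules a prefetch
#     ...                                               # do other work meanwhile
#     data = lazy.get(sample)  # waits for / cancels the prefetch, attaches the
#                              # SourceInfo to `sample`, and returns the data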
@edataclass
class _PendingTask:
    """Dataclass for storing a pending background task."""

    # The dataset that we're caching.
    ds: FileStore
    # The file name that we're caching.
    fname: str
    # The future for the background task that sends the data to the cache.
    send_to_cache_future: Future
    # The number of references to the cache entry.
    refcount: int = 1
    # The size of the data to be cached.
    data_size: int = 0
    # Whether the data is required now, i.e. a reading thread is waiting for it.
    require_data_now: bool = False
    # The path to the cache file.
    cache_path: Optional[Path] = None
    # The source info for the data.
    source_info: Optional[SourceInfo] = None
class FileStoreCachePool(CachePool, ForkMixin):
    """
    Manages a thread pool to pre-fetch data onto an SSD cache.

    Each (ds, fname) has one Future (one read). Multiple requests share that same
    future. We track usage with a refcount.

    To avoid multi-process collisions, we generate a random subfolder for each instance.
    """

    cache_dir: Path
    max_cache_size: int
    max_cache_count: int
    current_cache_size: int
    current_cache_count: int
    method: Literal["raw", "pickle"]
    # Thread pool for out-caching tasks
    _worker_pool: Optional[ThreadPoolExecutor] = None
    # (ds.path, fname) -> _PendingTask
    _pending_tasks: Dict[Tuple[str, str], _PendingTask]
    # Lock for all shared structures
    _lock: threading.Lock
    # Condition variable to signal when cache space is available
    _cache_space_available: threading.Condition
    # Whether the pool is shutting down
    _shutting_down: bool = False
    def __init__(
        self,
        *,
        parent_cache_dir: Optional[Path] = None,
        num_workers: int = 8,
        max_cache_size_gbytes: float = 1024,
        max_cache_count: int = 10_000_000,
        method: Literal["raw", "pickle"] = "raw",
    ):
        """
        Initialize the cache pool.

        Args:
            parent_cache_dir: The parent directory for the cache.
            num_workers: The number of worker threads to use for copying the data to the
                cache for lazy loading.
            max_cache_size_gbytes: The maximum size of the cache in gigabytes. If the
                cache exceeds this size, the prefetching will wait until the cache is
                below this size.
            max_cache_count: The maximum number of files in the cache. If the cache
                exceeds this number, the prefetching will wait until the cache is below
                this number.
            method: The method to use for caching. "raw": store the non-decoded raw
                data. "pickle": first decode the data and then store the pickled data.
        """
        super().__init__()

        # If no parent directory is given, use the system temp directory
        if parent_cache_dir is None:
            parent_cache_dir = Path(tempfile.gettempdir())

        self.parent_cache_dir = parent_cache_dir
        self.num_workers = num_workers

        # Initialize the cache pool (process volatile fields)
        self.__after_fork__(initial=True)

        self.method = method

        # We'll store _pending_tasks in the form:
        #   (ds.path, fname) -> _PendingTask
        self._pending_tasks = {}

        # Cache size management
        self.max_cache_size = int(max_cache_size_gbytes * (1024**3))
        self.max_cache_count = max_cache_count
        self.current_cache_size = 0
        self.current_cache_count = 0

        # A lock to protect all shared structures
        self._lock = threading.Lock()
        # Condition variable to signal when cache space is available
        self._cache_space_available = threading.Condition(self._lock)
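    # Illustrative sketch (not part of the module): constructing a pool. The directory
    # and limits below are hypothetical. "raw" caches the undecoded bytes and decodes
    # on read; "pickle" decodes once up front and caches the pickled result.
    #
    #     pool = FileStoreCachePool(
    #         parent_cache_dir=Path("/mnt/local_ssd"),  # hypothetical fast local disk
    #         num_workers=4,
    #         max_cache_size_gbytes=64,
    #         max_cache_count=100_000,
    #         method="raw",
    #     )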
    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """
        Synchronous read from the dataset (no cache usage).
        """
        return ds.get(fname, sample)
    def _get_data(self, ds: FileStore, fname: str, entry: _PendingTask) -> tuple[Any, SourceInfo]:
        """
        Get the data for a given file from the cache and purge the cache entry if no
        references are left.

        * If the cache-out is complete, read from cache.
        * If the cache-out is currently prefetching the data to local storage, wait
          until it's done.
        * If the cache-out job is waiting for space, skip the cache and do a direct read.
        * If the cache-out job is queued for caching, cancel it and do a direct read.
        * If the cache-out job failed, raise the exception through and keep it for other
          references.
        * If the cache-out job is cancelled, requeue it if there are other references
          waiting for it.
        """
        result: tuple[Any, SourceInfo]
        with self._lock:
            try:
                # Attempt to cancel if the job hasn't started
                if entry.send_to_cache_future.cancel():
                    was_cached = False
                    try:
                        # Cancelled => job never ran. We'll do a direct read.
                        result = ds[fname]
                    finally:
                        # Decrement refcount
                        self._decrement_refcount_and_cleanup(key=(ds.get_path(), fname))
                else:
                    # Future is already running or done.
                    # Release the lock so the background job can proceed, then reacquire
                    # it after waiting. Otherwise we might block the worker.
                    entry.require_data_now = True
                    self._cache_space_available.notify_all()
                    self._lock.release()
                    # If the job failed, keep the exception for other references.
                    was_cached = True
                    try:
                        # Can raise an exception if the job failed
                        was_cached = entry.send_to_cache_future.result()
                        if was_cached:
                            # The job is complete; read from cache
                            result = self._read_from_cache(entry)
                        else:
                            # The job was skipped (no space / shutdown), so we'll do a
                            # direct read
                            result = ds[fname]
                    finally:
                        self._lock.acquire()
                        entry.require_data_now = False
                        # Decrement refcount
                        self._decrement_refcount_and_cleanup(key=(ds.get_path(), fname))
            finally:
                if entry.refcount > 0 and not was_cached:
                    # TODO: Could write to cache here, data is already fetched.
                    #  Write the result to the cache.
                    # Requeue the job, there is another reference to the cache entry
                    entry.send_to_cache_future = self._worker_pool.submit(
                        self._cache_out_task, ds, fname, entry
                    )
        return result

    def _cache_out_task(self, ds: FileStore, fname: str, entry: _PendingTask) -> bool:
        with self._lock:
            if self._shutting_down:
                return False

        # Perform the data read (outside the lock, so readers are not serialized)
        if self.method == "raw":
            if isinstance(ds, DecodeFileStore):
                data, entry.source_info = ds.inner_reader[fname]
            else:
                data, entry.source_info = ds[fname]
        elif self.method == "pickle":
            data, entry.source_info = ds[fname]
            data = pickle.dumps(data)
        else:
            raise ValueError(f"Invalid method: {self.method}")

        # Wait until there's enough space in the cache
        with self._lock:
            entry.data_size = file_size = len(data)
            while (
                self.current_cache_count + 1 > self.max_cache_count
                or self.current_cache_size + entry.data_size > self.max_cache_size
            ):
                # Release the lock and wait for notification
                self._cache_space_available.wait()
                if entry.require_data_now or self._shutting_down:
                    # At least one reference requires the data now; stop waiting for
                    # space and exit immediately
                    return False

            # Reserve the space
            self.current_cache_size += file_size
            self.current_cache_count += 1

            if self._shutting_down or entry.refcount <= 0:
                # No more references to this background job, don't write to cache
                return False

        try:
            assert entry.cache_path is None, (
                f"cache_path should be None, but is {entry.cache_path!r}"
            )
            # Write to cache
            cache_path = self._make_cache_path(ds, fname)
            self._write_to_cache(cache_path, data)
        except:
            with self._lock:
                # Revert the space reservation
                self.current_cache_size -= file_size
                self.current_cache_count -= 1
                self._cache_space_available.notify_all()
            raise
        else:
            with self._lock:
                entry.cache_path = cache_path
                # print(
                #     f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Wrote to cache {cache_path} (rc={entry.refcount}, size={file_size}, name={fname})\n",
                #     end="",
                # )

        # Data is cached now, return True
        return True
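    # Minimal sketch (illustrative, not part of the module) of the reservation pattern
    # _cache_out_task uses above: producers wait on a Condition until a bounded budget
    # has room, reserve while holding the lock, and consumers notify_all() after
    # freeing space. All names below are made up for the sketch.
    #
    #     budget = threading.Condition()
    #     used, LIMIT = 0, 1024
    #
    #     def reserve(size: int) -> None:
    #         global used
    #         with budget:
    #             while used + size > LIMIT:
    #                 budget.wait()  # releases the underlying lock while blocked
    #             used += size  # reservation happens under the lock
    #
    #     def release(size: int) -> None:
    #         global used
    #         with budget:
    #             used -= size
    #             budget.notify_all()  # wake all waiters; they re-check the predicate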
    def get_lazy(self, ds: FileStore, fname: str) -> FileCacheLazy:
        """
        Schedule a background pre-fetch. If multiple calls come in for the same
        (ds, fname), they'll share the same Future and increment reference counts.
        """
        key = (ds.get_path(), fname)
        with self._lock:
            if self._shutting_down:
                raise RuntimeError("Cache pool is already shutting down")
            entry = self._pending_tasks.get(key)
            if entry:
                # Already have a background task for this (ds, fname)
                entry.refcount += 1
            else:
                # Create a new background task
                entry = _PendingTask(
                    ds=ds,
                    fname=fname,
                    send_to_cache_future=None,
                )
                self._pending_tasks[key] = entry
                entry.send_to_cache_future = self._worker_pool.submit(
                    self._cache_out_task, ds, fname, entry
                )
        return FileCacheLazy(ds=ds, fname=fname, pool=self, entry=entry)
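    # Illustrative sketch (not part of the module): repeated get_lazy() calls for the
    # same (ds, fname) share one _PendingTask and one Future; only the refcount grows.
    # `store` is a hypothetical FileStore.
    #
    #     a = pool.get_lazy(store, "shards/000042.bin")
    #     b = pool.get_lazy(store, "shards/000042.bin")
    #     assert a.entry is b.entry  # shared task, refcount == 2
    #     a.get()  # refcount -> 1 (reads from cache or directly, depending on timing)
    #     b.get()  # refcount -> 0; the cached file, if written, is removed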
    def close(self) -> None:
        """
        Shutdown the pool, wait for tasks, and clear our structures.
        """
        with self._lock:
            self._shutting_down = True
            for entry in self._pending_tasks.values():
                entry.send_to_cache_future.cancel()
            self._cache_space_available.notify_all()

        self._worker_pool.shutdown(wait=True)

        with self._lock:
            self._pending_tasks.clear()
    def _decrement_refcount_and_cleanup(self, key: Tuple[str, str]) -> None:
        """
        Decrement the reference count in `_pending_tasks`. If it hits zero, remove the
        entry and its cached file.

        Assumes the caller holds `self._lock`.
        """
        entry = self._pending_tasks.get(key)
        if not entry:
            # Already cleaned up
            return

        entry.refcount -= 1
        if entry.refcount <= 0:
            # No more references to this background job
            del self._pending_tasks[key]
            self._remove_cached_file(entry)
            assert entry.refcount == 0, f"refcount should be 0: {entry.refcount}"

    def _make_cache_path(self, ds: FileStore, fname: str) -> Path:
        # This is safe, because the parent cache dir is unique per instance.
        ds_hash = hashlib.md5(ds.get_path().encode("utf-8")).hexdigest()
        fn_hash = hashlib.md5(fname.encode("utf-8")).hexdigest()
        # ds_hash = str(ds.get_path()).replace("/", "_")
        # fn_hash = fname.replace("/", "_")
        return self.cache_dir / f"{ds_hash}_{fn_hash}"

    def _write_to_cache(self, path: Path, data: bytes) -> None:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            f.write(data)

    def _read_from_cache(self, entry: _PendingTask) -> tuple[Any, SourceInfo]:
        assert entry.source_info is not None, "source_info should have been set"
        with open(entry.cache_path, "rb") as f:
            if self.method == "raw":
                raw = f.read()
                if isinstance(entry.ds, DecodeFileStore):
                    return entry.ds.decoder.decode(entry.fname, raw), entry.source_info
                else:
                    return raw, entry.source_info
            else:
                return pickle.load(f), entry.source_info

    def _remove_cached_file(self, entry: _PendingTask) -> None:
        """
        Removes a file from disk and updates size counters.

        Assumes the caller holds `self._lock`.
        """
        if entry.cache_path is None:
            return
        if not entry.cache_path.exists():
            return
        try:
            # print(
            #     f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Removing cached file {entry.cache_path} (rc={entry.refcount})\n",
            #     end="",
            # )
            entry.cache_path.unlink()
        except OSError:
            pass
        entry.cache_path = None
        if entry.data_size > 0:
            self.current_cache_size -= entry.data_size
            self.current_cache_count -= 1
            # Notify waiting threads that space is now available
            self._cache_space_available.notify_all()

    def __before_fork__(self):
        # Ensure the worker pool is shut down before the fork
        assert len(self._pending_tasks) == 0, "Pending tasks should be empty before fork"
        # print(
        #     f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Before fork for oid={id(self)} random_suffix={self.cache_dir.name!r}\n",
        #     end="",
        # )
        self._worker_pool.shutdown(wait=True)
        self._worker_pool = None

    def __after_in_child_fork__(self):
        self.__after_fork__()

    def __after_in_parent_fork__(self):
        self.__after_fork__()

    def __after_fork__(self, initial: bool = False):
        # Create a random subdirectory name to avoid collisions with other processes.
        # As the global random generator is cloned across processes, we need to use a
        # process-specific seed.
        random_suffix = "".join(
            random.Random(os.getpid() ^ random.randint(0, 2**32)).choices(
                string.ascii_lowercase + string.digits, k=16
            )
        )
        assert self._worker_pool is None
        self._worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, thread_name_prefix="CacheWorker"
        )
        self.cache_dir = (self.parent_cache_dir / f"cache_{random_suffix}").resolve()
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # if initial:
        #     print(
        #         f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Init oid={id(self)} random_suffix={random_suffix!r}\n",
        #         end="",
        #     )
        # else:
        #     print(
        #         f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: After fork for pid={os.getpid()} oid={id(self)} random_suffix={random_suffix!r}\n",
        #         end="",
        #     )

    def __str__(self):
        return (
            f"FileStoreCachePool(cache_dir={self.cache_dir}, "
            f"max_cache_size={self.max_cache_size}, max_cache_count={self.max_cache_count}, "
            f"method={self.method}, current_cache_size={self.current_cache_size}, "
            f"current_cache_count={self.current_cache_count})"
        )