Source code for nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Storage writer for PyT Distributed format allowing asynchronous save. """
import dataclasses
import hashlib
import inspect
import logging
import os

# Issue: [B403:blacklist] Consider possible security implications associated with pickle module.
# Severity: Low   Confidence: High
# CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
# More Info: https://bandit.readthedocs.io/en/1.8.3/blacklists/blacklist_imports.html#b403-import-pickle
import pickle  # nosec
import queue
import threading
from functools import partial
from heapq import heappop, heappush
from itertools import chain
from operator import itemgetter
from time import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.api import WRAPPED_EXCEPTION, _wrap_exception
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.metadata import Metadata

try:
    from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms
except ImportError:
    _StorageWriterTransforms = Any

from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future

try:
    import psutil

    HAVE_PSUTIL = True
except ImportError:
    HAVE_PSUTIL = False

from ..utils import _disable_gc
from .core import PersistentAsyncCaller

logger = logging.getLogger(__name__)

WriteBucket = Tuple[str, str, Tuple[list, list]]  # represents writes to a single file

_results_queue = None


[docs] class ConsistentDataIdentifier: """Identifier for consistent data structure stored in worker cache. This allows passing a lightweight identifier instead of pickling the entire data structure (which includes IPC handles) across process boundaries. """ def __init__(self, key: str): self.key = key
def _compute_data_structure_key_from_plan(items: List[WriteItem]) -> str: """Compute a hash key based on plan items only (no data resolution needed). This creates a deterministic key from plan metadata that's available without resolving the actual tensor data. Args: items: List of WriteItem from the plan Returns: Hex-digest string key representing the data structure """ structure_info = [] for item in items: # Include item metadata that defines the structure item_info = ( item.index.fqn, # Fully qualified name item.type, # WriteItemType (BYTE_IO or TENSOR) ) # Include metadata from plan (available without resolving data) if item.tensor_data is not None: # Use tensor metadata from the plan data_info = ( tuple(item.tensor_data.chunk.sizes), # Tensor chunk shape str(item.tensor_data.properties.dtype), # Data type ) else: # For non-tensor data (BYTE_IO), use placeholder data_info = (("BYTE_IO",), "BYTE_IO") structure_info.append((item_info, data_info)) # Use SHA-256 for collision resistance and cross-process stability # (Python's built-in hash() is randomized per-process and collision-prone) return hashlib.sha256(str(structure_info).encode()).hexdigest()
[docs] @_disable_gc() def get_write_results_queue(mp_mode: str = 'spawn') -> mp.Queue: """Get or create a multiprocessing queue for write results. Args: mp_mode (str): Multiprocessing context mode. Defaults to 'spawn'. Returns: mp.Queue: Queue for collecting write results. """ global _results_queue if _results_queue is None: ctx = mp.get_context(mp_mode) _results_queue = ctx.Manager().Queue() return _results_queue
[docs] class FileSystemWriterAsync(FileSystemWriter): """ Async-enabled implementation of FileSystemWriter using file I/O. This class does not spawn the async process itself but relies on an external async mechanism. **Flow:** 1. Call `write_data` 2. Externally start an async process with `get_save_function_and_args` and its arguments. 3. The async function calls `write_preloaded_data_multithread` (threads) or `write_preloaded_data_multiproc` (processes) across multiple workers. 4. Once saving is finalized on all ranks, call `super().finish` with the results stored in `self.writer_result`. **Note:** Step (3) can also be executed synchronously. Currently, it is assumed that a separate writer is created for each ckpt save (intermediate state is stored as writer attributes). """ # Class-level cache to track identifiers that have been sent to worker across instances _cached_identifiers: set = set() def __init__( self, path: Union[str, os.PathLike], *args, separation_hint: Optional[str] = None, use_msc: bool = False, is_multiproc_io: bool = False, use_cached_data_structure: bool = False, **kwargs, ): self.checkpoint_dir = path self.use_msc = use_msc self.open_file = kwargs.pop("open_file", open) # for overriding in tests super().__init__(path, *args, **kwargs) if not self.single_file_per_rank: raise NotImplementedError( 'single_file_per_rank flag not supported for FileSystemWriterAsync' ) self.can_run_decentralized_global_plan: bool = True # Intermediate state between preparation and finalization self.has_data_to_write: bool = False self.results_queue: Optional[mp.Queue] = None self.separation_hint = separation_hint self.use_cached_data_structure = use_cached_data_structure self.consistent_data_identifier: Optional[ConsistentDataIdentifier] = None # When this flag is True, the FileWriter can create multiple child processes # to parallelize File IO in the background async checkpoint process. # Setting this flag to False (default) uses multi-threading to parallelize File IO. # Note: multi-proc IO requires is_daemon=False on PersistentAsyncCaller (AsyncCallsQueue), # whereas the default multithreaded IO is compatible with is_daemon=True (the default). self.is_multi_proc_io = is_multiproc_io
[docs] def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None: """ First stage of async saving. Resolve data and store in compact format. Separates data into GPU tensors (potentially cacheable), CPU tensors (always fresh), and ByteIO (always fresh). Bucket creation is deferred to `preload_tensors` so that it can run in the persistent worker process and take advantage of the data cache. Args: plan (SavePlan): save plan generated by the PyT Distributed compatible planner planner (SavePlanner): save planner used to resolve the bytes and tensor data Returns: None, but stores the resolved plan data in instance attributes """ start = time() logger.debug(f"thread_count: {self.thread_count}, time: {start}") if self.separation_hint: assert ( self.thread_count > 1 ), "thread_count must be at least 2 if separation_hint is provided" def _clone_or_dequantize_if_needed(ten: torch.Tensor): """Clone if we detect incontiguous storage for CPU tensors. Makes sure we perform a `clone` only if we detect incontiguous storage, so that we don't blow up host memory unnecessarily. For GPU tensors, returns as-is since they'll be moved to CPU in preload_tensors. Returns: (tensor, was_dequantized): the processed tensor and a bool indicating whether dequantize() was called. Tracked explicitly because some frameworks (e.g. TransformerEngine MXFP8) use a fake bfloat16 dtype for quantized tensors, making dtype-based detection unreliable. """ ten = ten.detach() if ten.device.type != "cpu": # We call ``dequantize`` if we detect a quantized tensor on GPU. # This is a workaround to avoid the issue of quantized tensors not being supported by the async writer. if ten.device.type == "cuda" and "dequantize" in type(ten).__dict__: ten = ten.dequantize() # GPU tensors will be moved to CPU in preload_tensors return ten, True # GPU tensors will be moved to CPU in preload_tensors return ten, False # For CPU tensors, clone if they are views to ensure contiguous storage is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize return (ten.clone() if is_view else ten), False def resolve_data(items): resolved = [] dequantized_flags = [] for item in items: data = planner.resolve_data(item) # Apply cloning/dequantize logic during resolution if isinstance(data, torch.Tensor): data, was_dequantized = _clone_or_dequantize_if_needed(data) else: was_dequantized = False resolved.append(data) dequantized_flags.append(was_dequantized) return resolved, dequantized_flags # Separate items by type: only GPU tensors can be cached via IPC # CPU tensors and ByteIO must be resolved fresh (cannot use IPC) tensor_items = [item for item in plan.items if item.type != WriteItemType.BYTE_IO] byte_io_items = [item for item in plan.items if item.type == WriteItemType.BYTE_IO] # Helper to separate resolved tensors into cacheable (GPU) vs uncached buckets. # Uncached tensors include: CPU tensors and dequantized GPU tensors. # Dequantized tensors are tracked explicitly via dequantized_flags because # some frameworks (e.g. TransformerEngine MXFP8) report bfloat16 as the dtype # for quantized tensors, making dtype-based detection unreliable. def separate_cacheable(items, resolved_data, dequantized_flags): """Separate tensor items into IPC-cacheable (GPU) and uncached categories. GPU tensors that were not dequantized are cacheable via IPC. CPU tensors and dequantized GPU tensors are always passed fresh and never cached. """ gpu_items, gpu_data = [], [] uncached_items, uncached_data = [], [] for item, data, was_dequantized in zip(items, resolved_data, dequantized_flags): if ( isinstance(data, torch.Tensor) and data.device.type == "cpu" ) or was_dequantized: uncached_items.append(item) uncached_data.append(data) else: gpu_items.append(item) gpu_data.append(data) return (gpu_items, gpu_data), (uncached_items, uncached_data) # Handle GPU tensor caching (only GPU tensors can benefit from IPC) # Uncached tensors (CPU or dequantized) are always resolved fresh if self.use_cached_data_structure and tensor_items: key = _compute_data_structure_key_from_plan(tensor_items) cache_exists = key in FileSystemWriterAsync._cached_identifiers # Always resolve tensors to separate uncached tensors (which can't be cached) resolved_tensors, dequantized_flags = resolve_data(tensor_items) (gpu_items, gpu_data), (uncached_items, uncached_data) = separate_cacheable( tensor_items, resolved_tensors, dequantized_flags ) if cache_exists: # Reuse cached GPU tensors from worker self.consistent_data_identifier = ConsistentDataIdentifier(key) self.cached_tensor_data = None # Signal to reuse cached data logger.debug( f"Reusing cached GPU tensors (key={key}), " f"resolved {len(uncached_items)} uncached tensors fresh" ) elif gpu_items: # First time caching - send GPU tensor data to worker self.consistent_data_identifier = ConsistentDataIdentifier(key) self.cached_tensor_data = (gpu_items, gpu_data) FileSystemWriterAsync._cached_identifiers.add(key) logger.debug( f"Caching {len(gpu_items)} GPU tensors (key={key}), " f"{len(uncached_items)} uncached tensors passed fresh" ) else: # No GPU tensors to cache; skip caching entirely self.consistent_data_identifier = None self.cached_tensor_data = None logger.debug( f"No GPU tensors to cache (key={key}), " f"{len(uncached_items)} uncached tensors passed fresh" ) # Uncached tensors are always passed fresh (never cached) self.uncached_tensor_data = (uncached_items, uncached_data) if uncached_items else None else: # No caching - resolve and separate all tensors self.consistent_data_identifier = None if tensor_items: resolved_tensors, dequantized_flags = resolve_data(tensor_items) (gpu_items, gpu_data), (uncached_items, uncached_data) = separate_cacheable( tensor_items, resolved_tensors, dequantized_flags ) self.cached_tensor_data = (gpu_items, gpu_data) if gpu_items else None self.uncached_tensor_data = ( (uncached_items, uncached_data) if uncached_items else None ) else: self.cached_tensor_data = None self.uncached_tensor_data = None # Always resolve ByteIO fresh (cannot use IPC) self.byte_io_data = ( (byte_io_items, resolve_data(byte_io_items)[0]) if byte_io_items else None ) self.storage_plan = plan.storage_data # Setup results queue if there's data to write self.has_data_to_write = len(plan.items) > 0 self.results_queue = get_write_results_queue() if self.has_data_to_write else None end = time() logger.debug(f"prepare_write_data, time: {end - start}")
[docs] def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]: """ Get function that saves the data to storage along with its arguments. Allows the external caller to apply the save function synchronously or asynchronously. Returns: None (if there is nothing to write on this rank) or a tuple of: 1) the function that saves the data. 2) the function that stages the GPU tensors to a destination for async checkpointing. This function should be self-contained. 3) arguments to that function in 1). """ if not self.has_data_to_write: return None, None, [] if self.use_msc: import multistorageclient as msc open_file = msc.open else: open_file = self.open_file transform_list = [self.transforms] if hasattr(self, 'transforms') else [] # Format: (identifier, (separation_hint, cached_tensor_data, # uncached_tensor_data, byte_io_data, thread_count, storage_plan)) # identifier is None when caching is disabled # uncached_tensor_data is always passed fresh (like ByteIO), never cached data_to_pass = ( self.consistent_data_identifier, ( self.separation_hint, self.cached_tensor_data, self.uncached_tensor_data, self.byte_io_data, self.thread_count, self.storage_plan, ), ) # Select write function based on IO mode if self.is_multi_proc_io: write_func = partial( self.write_preloaded_data_multiproc, transform_list, self.use_msc, open_file ) else: write_func = partial( self.write_preloaded_data_multithread, transform_list, self.use_msc, open_file ) return ( write_func, partial(self.preload_tensors, (str(self.checkpoint_dir), data_to_pass), True), [torch.distributed.get_rank(), None, self.results_queue], )
[docs] @staticmethod def preload_tensors(resolved_plan_data: Tuple, non_blocking=True) -> List[WriteBucket]: """ Creates write_buckets and preloads tensors to host memory. This runs in the persistent worker process. Bucket creation is done here (not in prepare_write_data) so that cached GPU tensor data stored in the worker process can be retrieved and reused without re-pickling. Args: resolved_plan_data (Tuple): Tuple containing (checkpoint_dir, (identifier, data_structure)) where: - identifier: ConsistentDataIdentifier (caching) or None - data_structure: (separation_hint, cached_tensor_data, uncached_tensor_data, byte_io_data, thread_count, storage_plan) non_blocking (bool, optional): Enable pinned D2H memcpy. Default is True. Returns: List[WriteBucket]: List of write buckets with tensors moved to CPU """ start = time() logger = logging.getLogger(__name__) checkpoint_dir, data_or_identifier = resolved_plan_data # Helper to combine GPU tensor, uncached tensor, and ByteIO data def combine_data(gpu_tensor_data, uncached_tensor_data, byte_io_data): items, resolved = [], [] for data in [gpu_tensor_data, uncached_tensor_data, byte_io_data]: if data: items.extend(data[0]) resolved.extend(data[1]) return items, resolved # Parse data structure: (identifier, (separation_hint, cached_tensor_data, # uncached_tensor_data, byte_io_data, thread_count, storage_plan)) # identifier is None when disabled, ConsistentDataIdentifier when enabled identifier, data_structure = data_or_identifier ( separation_hint, cached_tensor_data, uncached_tensor_data, byte_io_data, thread_count, storage_plan, ) = data_structure if isinstance(identifier, ConsistentDataIdentifier): # Caching enabled: get or cache GPU tensor data in the worker process # Uncached tensors (CPU or dequantized) are NOT cached (treated like ByteIO) key = identifier.key if cached_tensor_data is not None: PersistentAsyncCaller._worker_data_cache[key] = cached_tensor_data logger.debug(f"Worker cached GPU tensors (key={key})") elif key in PersistentAsyncCaller._worker_data_cache: cached_tensor_data = PersistentAsyncCaller._worker_data_cache[key] logger.debug(f"Worker retrieved cached GPU tensors (key={key})") else: raise RuntimeError(f"Worker cache miss for key {key}. Worker may have restarted.") # else: identifier is None, no caching needed items, resolved_data = combine_data(cached_tensor_data, uncached_tensor_data, byte_io_data) logger.debug(f"preload_tensors: thread_count: {thread_count}, time: {start}") # Create buckets from items bins = thread_count // 2 if separation_hint is not None else thread_count item_buckets = _split_by_size_and_type(bins, items) logger.debug(f"preload_tensors: bucket_prep, time: {time() - start}") # Create a mapping from items to resolved data item_to_data = {id(item): data for item, data in zip(items, resolved_data)} file_count = 0 def gen_file(prefix=""): nonlocal file_count file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" file_count += 1 return file_name # Build write_buckets with items grouped by file, assigning one per worker write_buckets = [] for group_name, group_buckets in _split_by_separation_hint( item_buckets, separation_hint ).items(): for bucket in group_buckets: bytes_data = [] tensor_data = [] for item in bucket: data = item_to_data[id(item)] if item.type == WriteItemType.BYTE_IO: bytes_data.append((item, data)) else: # Tensor data (GPU or CPU) - already cloned if needed tensor_data.append((item, data)) if len(bytes_data) > 0 or len(tensor_data) > 0: file_name = gen_file(prefix=group_name) write_buckets.append( ( os.path.join(checkpoint_dir, file_name), file_name, (bytes_data, tensor_data), ) ) # Now move GPU tensors to CPU (CPU tensors are already on CPU) result: List[WriteBucket] = [] for bucket in write_buckets: bucket_path, bucket_key, bucket_data = bucket bytes_data, tensor_data = bucket_data tensor_list = [] for item, tensor in tensor_data: # Move to CPU if needed (no-op if already on CPU) tensor_list.append((item, tensor.to("cpu", non_blocking=non_blocking))) result.append((bucket_path, bucket_key, (bytes_data, tensor_list))) if non_blocking: torch.cuda.synchronize() end = time() logger.debug(f"preload_tensors: D2H and bucket creation, time: {end - start}") return result
@staticmethod def _initialize_write_execution(rank: int) -> Tuple[logging.Logger, float, dict]: """ Common initialization for write execution. Args: rank (int): training rank Returns: Tuple[logging.Logger, float, dict]: logger, start time, and initialized results dict """ logger = logging.getLogger(__name__) w_start = time() write_results_or_exc: Union[dict, Exception] = dict() return logger, w_start, write_results_or_exc @staticmethod def _build_worker_kwargs( worker_idx: int, write_bucket: WriteBucket, use_msc: bool, worker_type: str, **extra_kwargs ) -> dict: """ Build kwargs for worker (thread or process). Args: worker_idx (int): index of the worker write_bucket (WriteBucket): data to write use_msc (bool): flag to indicate use of multi storage client worker_type (str): 'thread' or 'proc' **extra_kwargs: additional worker-specific kwargs Returns: dict: kwargs for the worker """ idx_key = f'local_{worker_type}_idx' kwargs = { idx_key: worker_idx, 'write_bucket': write_bucket, 'use_fsync': True, } if use_msc: kwargs['use_msc'] = use_msc kwargs.update(extra_kwargs) return kwargs @staticmethod def _finalize_write_execution( global_results_queue: mp.Queue, write_results_or_exc: Union[dict, Exception], rank: int, w_start: float, worker_type: str, logger: logging.Logger, ) -> None: """ Common finalization for write execution. Args: global_results_queue (mp.Queue): queue to put results write_results_or_exc (Union[dict, Exception]): results or exception rank (int): training rank w_start (float): start time worker_type (str): 'MultiProc' or 'MultiThread' logger (logging.Logger): logger instance """ global_results_queue.put(write_results_or_exc) w_end = time() logger.debug( f"{worker_type} Background Async worker time to persist: {w_end - w_start} s for rank={rank}" )
[docs] @staticmethod @_disable_gc() def write_preloaded_data_multiproc( transform_list: List[_StorageWriterTransforms], use_msc: bool, open_file: Callable, rank: int, write_buckets: List[WriteBucket], global_results_queue: mp.Queue, ) -> None: """ Performs saving data to storage with multiple processes. Starts predefined number of processes and uses 2 queues to make sure the results are complete: - local_results_queue - to send the actual results - count_queue - small queue to mark worker as completed Using just one queue disallowed proper exception handling. This method is meant to be run in a forked subprocess. Triggering GC during execution leads to CUDA errors (cleaning up tensors owned by the parent process). To prevent this, we disable the GC explicitly for this function with _disable_gc. Note: requires is_daemon=False on the PersistentAsyncCaller, because daemon processes cannot spawn child processes. Args: transform_list (List[_StorageWriterTransforms]): streaming transforms list use_msc (bool): flag to indicate use of multi storage client open_file (Callable): file open callable rank (int): training rank write_buckets (List[WriteBucket]): write plan global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]] (or an Exception) from parallel write processes to the main training process Returns: None """ logger, w_start, write_results_or_exc = FileSystemWriterAsync._initialize_write_execution( rank ) ctx = mp.get_context('fork') local_results_queue = ctx.Queue() count_queue = ctx.JoinableQueue() p_list = [] for i, write_bucket in enumerate(write_buckets): try: current_process = mp.current_process() if current_process.daemon: err_msg = "Invalid Setup! User cannot establish a daemon Async worker and then use Multi-Proc File IO." logger.error(err_msg) raise RuntimeError(err_msg) count_queue.put(i) kwargs = FileSystemWriterAsync._build_worker_kwargs( worker_idx=i, write_bucket=write_bucket, use_msc=use_msc, worker_type='proc', results_queue=local_results_queue, count_queue=count_queue, ) p_list.append( ctx.Process( target=partial( FileSystemWriterAsync.write_preloaded_data_proc, transform_list, open_file, ), kwargs=kwargs, ) ) except Exception as e: err_msg = f'An error is caught while a proc {i} is created, error: {e}' logger.error(err_msg) write_results_or_exc = RuntimeError(err_msg) if not isinstance(write_results_or_exc, Exception): for p in p_list: p.start() logger.debug('FileSystemWriterAsync: collecting worker results...') # To make sure all nodes are completed count_queue.join() # At this point, all workers completed, so the queue should have exactly # `len(write_buckets)` items for proc_idx in range(len(write_buckets)): try: local_proc_idx, local_results_or_exc = local_results_queue.get() except queue.Empty: write_results_or_exc = RuntimeError( 'Unexpected empty `local_results_queue`' f' (got only {proc_idx}/{len(write_buckets)} items)' ) break else: if isinstance(local_results_or_exc, Exception): err_msg = ( f"Local process {local_proc_idx} encountered" f" an error: {local_results_or_exc}" ) logger.error(err_msg) write_results_or_exc = local_results_or_exc break assert isinstance(local_results_or_exc, list), type(local_results_or_exc) write_results_or_exc[local_proc_idx] = local_results_or_exc p_list[local_proc_idx].join() logger.debug('FileSystemWriterAsync: collected worker results successfully') FileSystemWriterAsync._finalize_write_execution( global_results_queue, write_results_or_exc, rank, w_start, "MultiProc", logger )
@staticmethod def _write_bucket_to_storage( transform_list: List[_StorageWriterTransforms], open_file: Callable, write_bucket: WriteBucket, use_fsync: bool, use_msc: bool, ) -> List[WriteResult]: """ Core logic for writing a bucket to storage. Args: transform_list (List[_StorageWriterTransforms]): streaming transforms list open_file (Callable): file open callable write_bucket (WriteBucket): data to write to storage use_fsync (bool): if True, calls os.fsync at the end of saving use_msc (bool): flag to indicate use of multi storage client Returns: List[WriteResult]: list of write results """ file_name, storage_key, (bytes_data, tensor_data) = write_bucket extra_kwargs = {} write_fn = _write_item if "serialization_format" in inspect.signature(_write_item).parameters: from torch.distributed.checkpoint.filesystem import SerializationFormat extra_kwargs['serialization_format'] = SerializationFormat.TORCH_SAVE if "transforms" in inspect.signature(_write_item).parameters: assert len(transform_list) <= 1 write_fn = partial(_write_item, *transform_list) local_results = [] with open_file(file_name, "wb") as stream: for write_item, data in bytes_data: local_results.append( write_fn(stream, data, write_item, storage_key, **extra_kwargs) ) for write_item, tensor in tensor_data: assert tensor.is_cpu local_results.append( write_fn(stream, tensor, write_item, storage_key, **extra_kwargs) ) if use_fsync: if use_msc: stream.fsync() else: os.fsync(stream.fileno()) return local_results
[docs] @staticmethod @_disable_gc() def write_preloaded_data_proc( transform_list: List[_StorageWriterTransforms], open_file: Callable, local_proc_idx: int, write_bucket: WriteBucket, results_queue: mp.SimpleQueue, count_queue: mp.JoinableQueue, use_fsync: bool, **kwargs, ) -> None: """ Performs actual data saving to storage (used by worker processes in multiproc mode). Args: local_proc_idx (int): index of a local process that performs writing write_bucket (WriteBucket): data to write to storage results_queue (mp.Queue): queue to return the write results to the proxy checkpoint process. count_queue (mp.JoinableQueue): queue to marks worker task as completed use_fsync (bool): if True, calls os.fsync at the end of saving Returns: None, the write results are put into the `results_queue` """ logger = logging.getLogger(__name__) logger.debug(f'{local_proc_idx} started') mem_before = _process_memory() use_msc = kwargs.get('use_msc', False) try: local_results = FileSystemWriterAsync._write_bucket_to_storage( transform_list, open_file, write_bucket, use_fsync, use_msc ) local_output = (local_proc_idx, local_results) except Exception as e: logger.debug(f'{local_proc_idx} failed') local_output = (local_proc_idx, e) results_queue.put(local_output) # Signal this process is done. count_queue.get() count_queue.task_done() mem_after = _process_memory() logger.debug( f"{local_proc_idx} consumed: {mem_after - mem_before}," f" before: {mem_before}, after: {mem_after}" )
[docs] @staticmethod @_disable_gc() def write_preloaded_data_multithread( transform_list: List[_StorageWriterTransforms], use_msc: bool, open_file: Callable, rank: int, write_buckets: List[WriteBucket], global_results_queue: mp.Queue, ) -> None: """ Performs saving data to storage with multiple threads. Uses threads (not processes) so that this can run safely inside a daemon process without spawning child processes. The last bucket runs on the calling thread to avoid thread creation overhead. Uses two queues for worker coordination: - local_results_queue - to collect write results from worker threads - count_queue - to signal worker completion (get + task_done / join). Triggering GC during execution can lead to CUDA errors when tensors are shared. To prevent this, we disable the GC explicitly for this function with _disable_gc. Args: transform_list (List[_StorageWriterTransforms]): streaming transforms list use_msc (bool): flag to indicate use of multi storage client for storage access open_file (Callable): file open callable rank (int): training rank write_buckets (List[WriteBucket]): write plan global_results_queue (mp.Queue): queue to send Dict[List[WriteResults]] (or an Exception) back to the main training process Returns: None """ logger = logging.getLogger(__name__) w_start = time() write_results_or_exc: Union[dict, Exception] = dict() local_results_queue: queue.Queue = queue.Queue() count_queue: queue.Queue = queue.Queue() thread_list: List[threading.Thread] = [] def check_local_output(local_results_or_exc, local_worker_idx): if isinstance(local_results_or_exc, Exception): err_msg = ( f"Local worker {local_worker_idx} encountered" f" an error: {local_results_or_exc}" ) logger.error(err_msg) raise local_results_or_exc for i, write_bucket in enumerate(write_buckets): try: kwargs = { "local_thread_idx": i, "write_bucket": write_bucket, "results_queue": local_results_queue, "count_queue": count_queue, "use_fsync": True, } if use_msc: kwargs["use_msc"] = use_msc # Parallel writers: spawn threads for all but the last bucket if i < len(write_buckets) - 1: count_queue.put(i) t = threading.Thread( target=partial( FileSystemWriterAsync.write_preloaded_data, transform_list, open_file ), kwargs=kwargs, ) thread_list.append(t) else: # Run last bucket on the calling thread (no thread overhead) kwargs['count_queue'] = None kwargs['results_queue'] = None logger.debug('FileSystemWriterAsync: main worker started') local_output = FileSystemWriterAsync.write_preloaded_data( transform_list, open_file, **kwargs ) if local_output is not None: logger.debug( 'FileSystemWriterAsync: main worker results successfully collected' ) check_local_output(local_output[1], local_output[0]) write_results_or_exc[local_output[0]] = local_output[1] except Exception as e: err_msg = f"An error is caught while starting worker {i}, error: {e}" logger.error(err_msg) write_results_or_exc = RuntimeError(err_msg) if not isinstance(write_results_or_exc, Exception) and len(thread_list) > 0: for t in thread_list: t.start() logger.debug("FileSystemWriterAsync: collecting worker results...") count_queue.join() for _ in range(len(write_buckets) - 1): try: local_thread_idx, local_results_or_exc = local_results_queue.get() except queue.Empty: write_results_or_exc = RuntimeError( "Unexpected empty `local_results_queue`" f" (expected {len(write_buckets) - 1} items)" ) break else: try: check_local_output(local_results_or_exc, local_thread_idx) except Exception as worker_exc: write_results_or_exc = worker_exc break write_results_or_exc[local_thread_idx] = local_results_or_exc for t in thread_list: t.join() logger.debug('FileSystemWriterAsync: collected worker results successfully') if isinstance(write_results_or_exc, dict) and len(write_results_or_exc) != len( write_buckets ): write_results_or_exc = RuntimeError( f"Incomplete write results: expected {len(write_buckets)} buckets," f" got {len(write_results_or_exc)}" ) global_results_queue.put(write_results_or_exc) w_end = time() logger.debug(f"{w_end}, rank: {rank}, write(sync,threads): {w_end - w_start}")
[docs] @staticmethod @_disable_gc() def write_preloaded_data( transform_list: List[_StorageWriterTransforms], open_file: Callable, local_thread_idx: int, write_bucket: WriteBucket, results_queue: Optional[queue.Queue], count_queue: Optional[queue.Queue], use_fsync: bool, **kwargs, ) -> Optional[Tuple[int, Union[List[WriteResult], Exception]]]: """ Performs actual data saving to storage (used by worker threads in multithread mode). Args: transform_list (List[_StorageWriterTransforms]): streaming transforms list open_file (Callable): file open callable local_thread_idx (int): index of the worker thread that performs writing write_bucket (WriteBucket): data to write to storage results_queue (queue.Queue): queue to return the write results. If None (main-thread worker), result is returned directly. count_queue (queue.Queue): queue to signal worker task completion. If None (main-thread worker), skipped. use_fsync (bool): if True, calls os.fsync at the end of saving Returns: None when running in a thread (results put in queue); result tuple when running as main-thread worker (results_queue is None) """ logger = logging.getLogger(__name__) logger.debug(f'{local_thread_idx} started') mem_before = _process_memory() use_msc = kwargs.get('use_msc', False) try: local_results = FileSystemWriterAsync._write_bucket_to_storage( transform_list, open_file, write_bucket, use_fsync, use_msc ) local_output = (local_thread_idx, local_results) except Exception as e: logger.debug(f'{local_thread_idx} failed with exception {e}') local_output = (local_thread_idx, e) if results_queue is not None: results_queue.put(local_output) if count_queue is not None: # Signal this thread is done. count_queue.get() count_queue.task_done() mem_after = _process_memory() logger.debug( f"{local_thread_idx} consumed: {mem_after - mem_before}," f" before: {mem_before}, after: {mem_after}" ) return local_output
[docs] def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]: """Write all items from ``plan``.""" raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')
[docs] def retrieve_write_results(self) -> Union[List[WriteResult], WRAPPED_EXCEPTION]: """ Turn the latest dict including write results from `self.results_queue` into a single results lists. Includes error check. Returns (Union(List[WriteResult], WRAPPED_EXCEPTION): the list of write results from all local workers performing the save, or a WRAPPED_EXCEPTION if an exception was raised during the writing process. """ if self.results_queue is None: write_results_or_exc = {} else: try: write_results_or_exc = self.results_queue.get_nowait() except queue.Empty: return _wrap_exception(RuntimeError('results_queue should not be empty')) if isinstance(write_results_or_exc, Exception): # Worker failed — its data cache may have been lost (e.g. after a restart). # Drop any identifier we recorded as cached so the next save re-populates it. if self.consistent_data_identifier is not None: FileSystemWriterAsync._cached_identifiers.discard( self.consistent_data_identifier.key ) try: raise RuntimeError( f'Worker failure: {write_results_or_exc}' ) from write_results_or_exc except Exception as e: return _wrap_exception(e) write_results: dict = write_results_or_exc if self.has_data_to_write and len(write_results) == 0: return _wrap_exception( RuntimeError( 'Worker returned empty results despite having data to write.' ' This probably indicates a worker failure.' ) ) return list(chain.from_iterable(write_results.values()))
[docs] def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: """Instead of assigning indices by plan order, uses PyT rank (same outcome). Args: local_plan (SavePlan): local plan to turn to a global plan (without interactions with other ranks) Returns: SavePlan - locally transformed plan equivalent to the plan that would be created by the coordinator """ return dataclasses.replace( local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_") )
[docs] def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: """ Finish the checkpointing process. Args: metadata (Metadata): metadata to save results (List[List[WriteResult]]): results to save """ if self.use_msc: import multistorageclient as msc storage_md = dict() for wr_list in results: storage_md.update({wr.index: wr.storage_data for wr in wr_list}) metadata.storage_data = storage_md # storage_meta was introduced since PyTorch 2.4 if "storage_meta" in inspect.signature(Metadata).parameters: metadata.storage_meta = self.storage_meta() path = os.path.join(self.checkpoint_dir, ".metadata") with msc.open(path, "wb") as metadata_file: # Issue: [B301:blacklist] Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue. # Severity: Medium Confidence: High # CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html) # More Info: https://bandit.readthedocs.io/en/1.8.3/blacklists/blacklist_calls.html#b301-pickle pickle.dump(metadata, metadata_file) # nosec else: super().finish(metadata, results)
[docs] def prepare_local_plan(self, plan: SavePlan) -> SavePlan: """ Prepare the local plan for the checkpointing process. """ if self.use_msc: import multistorageclient as msc msc.os.makedirs(str(self.checkpoint_dir), exist_ok=True) else: super().prepare_local_plan(plan) return plan
@property def checkpoint_id(self) -> Union[str, os.PathLike]: """ return the checkpoint_id that will be used to save the checkpoint. """ return str(self.checkpoint_dir)
[docs] @classmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: """ Validate the checkpoint_id that will be used to save the checkpoint. This method is available in PyTorch 2.3 and above. """ checkpoint_id_str = str(checkpoint_id) if checkpoint_id_str.startswith("msc://"): return True if hasattr(FileSystemWriter, "validate_checkpoint_id"): return FileSystemWriter.validate_checkpoint_id(checkpoint_id) return False
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]: """ Splits write items according to item size into close to uniform bins. Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type, but with a fixed _item_size function. Args: bins (int): numbers of bins to split to items (List[WriteItem]): list of write items Returns (List[List[WriteItem]]): write items split to bins """ if bins == 1: return [items] bytes_items: List[WriteItem] = [] tensor_items: List[WriteItem] = [] for wi in items: container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items container.append(wi) buckets: List[List[WriteItem]] = [[] for _ in range(bins)] bucket_sizes = [0 for _ in range(bins)] # Assign bytes with a simple round-robin for i, item in enumerate(bytes_items): buckets[i % bins].append(item) # Sort tensor items by size in decreasing order once and store the size with item sized_tensors = [(item, _item_size(item)) for item in tensor_items] sized_tensors.sort(key=itemgetter(1), reverse=True) # Use a min heap for bin assignment # Store (total_size_of_bin, bin_index) tuples heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)] # Assign tensors using heap for item, size in sized_tensors: total_bin_size, bin_idx = heappop(heap) buckets[bin_idx].append(item) heappush(heap, (total_bin_size + size, bin_idx)) return buckets def _split_by_separation_hint( buckets: List[List[WriteItem]], separation_hint: Optional[str] = None ) -> Dict[str, List[List[WriteItem]]]: """ Splits buckets into those whose keys begin with the separation_hint and those whose keys do not Args: buckets (List[List[WriteItem]]): buckets to split separation_hint (Optional[str]): optional prefix to split on Returns (Dict[str, List[List[WriteItem]]]): a dictionary mapping the prefix to the relevant buckets """ bins = len(buckets) buckets_with_separation_hint: Dict[str, List[List[WriteItem]]] = {} if separation_hint is not None: buckets_default: List[List[WriteItem]] = [[] for _ in range(bins)] buckets_hint: List[List[WriteItem]] = [[] for _ in range(bins)] for i in range(bins): for item in buckets[i]: if item.index.fqn.startswith(separation_hint): buckets_hint[i].append(item) else: buckets_default[i].append(item) buckets_with_separation_hint[""] = buckets_default buckets_with_separation_hint[separation_hint] = buckets_hint else: buckets_with_separation_hint[""] = buckets return buckets_with_separation_hint def _item_size(item: WriteItem) -> int: """ Calculates size (in bytes) of a single write item. Same as torch.distributed.checkpoint.filesystem._item_size, but fixes computing chunk size (with item.tensor_data.chunk.sizes) Args: item (WriteItem): write item to compute the size of Returns (int): size of an item in bytes """ size = 1 assert item.tensor_data is not None # can't use math.prod as PT needs to support older python for s in item.tensor_data.chunk.sizes: size *= s dtype = item.tensor_data.properties.dtype return size * torch._utils._element_size(dtype) def _process_memory() -> int: """ Get memory used by current process. Returns (int): memory used by current process """ if not HAVE_PSUTIL: raise RuntimeError("psutil is not installed, please install it with `pip install psutil`") process = psutil.Process(os.getpid()) mem_info = process.memory_info() return mem_info.rss