Source code for nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Storage writer for PyT Distributed format allowing asynchronous save. """
import dataclasses
import inspect
import logging
import os

# Issue: [B403:blacklist] Consider possible security implications associated with pickle module.
# Severity: Low   Confidence: High
# CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
# More Info: https://bandit.readthedocs.io/en/1.8.3/blacklists/blacklist_imports.html#b403-import-pickle
import pickle  # nosec
import queue
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from heapq import heappop, heappush
from itertools import chain
from operator import itemgetter
from pathlib import Path
from time import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.api import WRAPPED_EXCEPTION, _wrap_exception
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.metadata import Metadata

try:
    from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms
except ImportError:
    _StorageWriterTransforms = Any

from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future

from ..utils import _disable_gc

logger = logging.getLogger(__name__)

WriteBucket = Tuple[Path, str, Tuple[list, list]]  # represents writes to a single file

_results_queue = None


def _get_write_results_queue():
    global _results_queue
    if _results_queue is None:
        ctx = mp.get_context('spawn')
        with _disable_gc():
            _results_queue = ctx.Manager().Queue()
    return _results_queue


class FileSystemWriterAsync(FileSystemWriter):
    """
    Async-enabled implementation of FileSystemWriter using file I/O.

    This class does not spawn the async process itself but relies on an external async mechanism.

    **Flow:**

    1. Call `prepare_write_data`.
    2. Externally start an async process with `get_save_function_and_args` and its arguments.
    3. The async function `writer_proxy_func` calls `write_preloaded_data` across multiple
       processes.
    4. Once saving is finalized on all ranks, call `super().finish` with the results returned
       by `retrieve_write_results`.

    **Note:** Step (3) can also be executed synchronously.

    Currently, it is assumed that a separate writer is created for each checkpoint save
    (intermediate state is stored as writer attributes).
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        *args,
        separation_hint: Optional[str] = None,
        is_multiproc_io: bool = True,
        **kwargs,
    ):
        self.checkpoint_dir = path
        self.use_msc = kwargs.pop("use_msc", False)
        self.open_file = kwargs.pop("open_file", open)  # for overriding in tests
        super().__init__(path, *args, **kwargs)
        if not self.single_file_per_rank:
            raise NotImplementedError(
                'single_file_per_rank flag not supported for FileSystemWriterAsync'
            )

        self.can_run_decentralized_global_plan: bool = True

        # Intermediate state between preparation and finalization
        self.write_buckets: Optional[List[WriteBucket]] = None
        self.results_queue: Optional[mp.Queue] = None
        self.separation_hint = separation_hint

        # When this flag is True, the FileWriter can create multiple child processes
        # to parallelize file I/O in the background async checkpoint process.
        # Setting this flag to False means we fall back to multi-threading to
        # parallelize file I/O.
        self.is_multi_proc_io = is_multiproc_io

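    # A minimal sketch of the intended call flow, for illustration only (the names
    # `plan`, `planner`, and `metadata` below are hypothetical placeholders; the
    # actual scheduling is done by the external async machinery):
    #
    #     writer = FileSystemWriterAsync(checkpoint_dir)
    #     writer.prepare_write_data(plan, planner)
    #     save_fn, stage_fn, save_args = writer.get_save_function_and_args()
    #     # The async mechanism stages GPU tensors via `stage_fn`, then runs
    #     # `save_fn(*save_args)` synchronously or in a background worker.
    #     ...
    #     results = writer.retrieve_write_results()  # after the save completes
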
    def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
        """
        First stage of async saving. Copy data to CPU and plan the local saving.

        Args:
            plan (SavePlan): save plan generated by the PyT Distributed compatible planner
            planner (SavePlanner): save planner used to resolve the bytes and tensor data

        Returns:
            None, but stores the save plan in `self.write_buckets`
        """
        storage_plan: _StoragePrefix = plan.storage_data
        start = time()
        logger.debug(f"thread_count: {self.thread_count}, time: {start}")
        if self.separation_hint:
            assert (
                self.thread_count > 1
            ), "thread_count must be at least 2 if separation_hint is provided"

        bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count
        item_buckets = _split_by_size_and_type(bins, plan.items)
        logger.debug(f"bucket_prep, time: {time() - start}")

        start = time()
        # Move tensors from GPU to CPU before starting async writing.
        # We do D2H synchronously for now.
        file_count = 0

        def gen_file(prefix=""):
            nonlocal file_count
            file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        def _clone_if_needed(ten: torch.Tensor):
            """Clone only if we detect non-contiguous storage for CPU tensors.

            Makes sure we perform a `clone` only if we detect non-contiguous storage,
            so that we don't blow up host memory unnecessarily.

            TODO: For a persistent worker, this should be changed to move the CPU
            tensor to shared memory.
            """
            ten = ten.detach()
            if ten.device.type != "cpu":
                # We do D2H later, when the async_request is scheduled,
                # for both sync and async checkpointing
                return ten
            is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
            return ten.clone() if is_view else ten

        # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
        self.write_buckets = []
        for group_name, group_buckets in _split_by_separation_hint(
            item_buckets, self.separation_hint
        ).items():
            for bucket in group_buckets:
                bytes_data = [
                    (item, planner.resolve_data(item))
                    for item in bucket
                    if item.type == WriteItemType.BYTE_IO
                ]
                tensor_data = [
                    (item, _clone_if_needed(planner.resolve_data(item)))
                    for item in bucket
                    if item.type != WriteItemType.BYTE_IO
                ]
                if len(bytes_data) > 0 or len(tensor_data) > 0:
                    file_name = gen_file(prefix=group_name)
                    self.write_buckets.append(
                        (
                            os.path.join(self.checkpoint_dir, file_name),
                            file_name,
                            (bytes_data, tensor_data),
                        )
                    )

        # Check if there is anything to write on this rank
        if len(self.write_buckets) > 0:
            assert len(self.write_buckets) <= self.thread_count, (
                len(self.write_buckets),
                self.thread_count,
            )
            self.results_queue = _get_write_results_queue()
        else:
            self.results_queue = None
        end = time()
        logger.debug(f"D2H and push, time: {end - start}")

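    # Illustration of the `_clone_if_needed` view check above (assumed numbers,
    # float32 tensors):
    #
    #     base = torch.zeros(4, 4)       # backing storage: 16 * 4 = 64 bytes
    #     view = base[:2]                # numel() * itemsize = 8 * 4 = 32 bytes
    #     view.untyped_storage().size()  # 64 != 32 -> detected as a view -> clone()
    #
    # Cloning the view yields a compact copy, so the bucket does not keep the whole
    # underlying storage alive when it is handed off for writing.
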
    def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]:
        """
        Get the function that saves the data to storage along with its arguments.
        Allows the external caller to apply the save function synchronously or asynchronously.

        Returns: a tuple of `(None, None, [])` if there is nothing to write on this rank,
            otherwise a tuple of:
                1) the function that saves the data.
                2) the function that stages the GPU tensors to a destination for async
                   checkpointing. This function should be self-contained.
                3) arguments to the function in 1).
        """
        if not self.write_buckets:
            return None, None, []

        if self.use_msc:
            import multistorageclient as msc

            open_file = msc.open
        else:
            open_file = self.open_file

        transform_list = [self.transforms] if hasattr(self, 'transforms') else []

        # Select the appropriate write function based on the I/O mode
        write_func = (
            self.write_preloaded_data_multiproc
            if self.is_multi_proc_io
            else self.write_preloaded_data_multithread_launcher
        )
        return (
            partial(write_func, transform_list, self.use_msc, open_file),
            partial(self.preload_tensors, self.write_buckets, True),
            [torch.distributed.get_rank(), self.write_buckets, self.results_queue],
        )

    @staticmethod
    def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]:
        """
        Preloads tensors in `write_buckets` to host memory.

        Args:
            write_buckets (List[WriteBucket]): list of `WriteBucket` objects that define
                what to save in a checkpoint.
            non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.

        Returns:
            List[WriteBucket]: write buckets with tensors moved to host memory.
        """
        result = []
        for bucket in write_buckets:
            file_name, storage_key, (bytes_data, tensor_data) = bucket
            tensor_data = [
                (item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
            ]
            result.append((file_name, storage_key, (bytes_data, tensor_data)))
        if non_blocking:
            torch.cuda.synchronize()
        return result

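    # Note (interpretation of the code above): `non_blocking=True` requests
    # asynchronous D2H copies (fully asynchronous only when the destination host
    # memory is pinned); the single `torch.cuda.synchronize()` afterwards
    # guarantees all copies have completed before the buckets are handed to the
    # writer processes/threads.
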
    @staticmethod
    def _initialize_write_execution(rank: int) -> Tuple[logging.Logger, float, dict]:
        """
        Common initialization for write execution.

        Args:
            rank (int): training rank

        Returns:
            Tuple[logging.Logger, float, dict]: logger, start time, and initialized results dict
        """
        logger = logging.getLogger(__name__)
        w_start = time()
        write_results_or_exc: Union[dict, Exception] = dict()
        return logger, w_start, write_results_or_exc

    @staticmethod
    def _build_worker_kwargs(
        worker_idx: int, write_bucket: WriteBucket, use_msc: bool, worker_type: str, **extra_kwargs
    ) -> dict:
        """
        Build kwargs for a worker (thread or process).

        Args:
            worker_idx (int): index of the worker
            write_bucket (WriteBucket): data to write
            use_msc (bool): flag to indicate use of multi storage client
            worker_type (str): 'thread' or 'proc'
            **extra_kwargs: additional worker-specific kwargs

        Returns:
            dict: kwargs for the worker
        """
        idx_key = f'local_{worker_type}_idx'
        kwargs = {
            idx_key: worker_idx,
            'write_bucket': write_bucket,
            'use_fsync': True,
        }
        if use_msc:
            kwargs['use_msc'] = use_msc
        kwargs.update(extra_kwargs)
        return kwargs

    @staticmethod
    def _finalize_write_execution(
        global_results_queue: mp.Queue,
        write_results_or_exc: Union[dict, Exception],
        rank: int,
        w_start: float,
        worker_type: str,
        logger: logging.Logger,
    ) -> None:
        """
        Common finalization for write execution.

        Args:
            global_results_queue (mp.Queue): queue to put results into
            write_results_or_exc (Union[dict, Exception]): results or exception
            rank (int): training rank
            w_start (float): start time
            worker_type (str): 'MultiProc' or 'MultiThread'
            logger (logging.Logger): logger instance
        """
        global_results_queue.put(write_results_or_exc)
        w_end = time()
        logger.debug(
            f"{worker_type} Background Async worker time to persist: {w_end - w_start} s"
            f" for rank={rank}"
        )

    @staticmethod
    @_disable_gc()
    def write_preloaded_data_multiproc(
        transform_list: List[_StorageWriterTransforms],
        use_msc: bool,
        open_file: Callable,
        rank: int,
        write_buckets: List[WriteBucket],
        global_results_queue: mp.Queue,
    ) -> None:
        """
        Performs saving data to storage with multiple processes.

        Starts a predefined number of processes and uses two queues to make sure the
        results are complete:
        - local_results_queue - to send the actual results
        - count_queue - small queue to mark a worker as completed

        Using just one queue would not allow proper exception handling.

        This method is meant to be run in a forked subprocess. Triggering GC during
        execution leads to CUDA errors (cleaning up tensors owned by the parent process).
        To prevent this, we disable the GC explicitly for this function with `_disable_gc`.

        Args:
            write_buckets (List[WriteBucket]): write plan
            global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]]
                (or an Exception) from parallel write processes to the main training process

        Returns: None
        """
        logger, w_start, write_results_or_exc = FileSystemWriterAsync._initialize_write_execution(
            rank
        )
        ctx = mp.get_context('fork')
        local_results_queue = ctx.Queue()
        count_queue = ctx.JoinableQueue()
        p_list = []
        for i, write_bucket in enumerate(write_buckets):
            try:
                current_process = mp.current_process()
                if current_process.daemon:
                    err_msg = (
                        "Invalid setup! User cannot establish a daemon async worker"
                        " and then use multi-proc file I/O."
                    )
                    logger.error(err_msg)
                    raise RuntimeError(err_msg)
                count_queue.put(i)
                kwargs = FileSystemWriterAsync._build_worker_kwargs(
                    worker_idx=i,
                    write_bucket=write_bucket,
                    use_msc=use_msc,
                    worker_type='proc',
                    results_queue=local_results_queue,
                    count_queue=count_queue,
                )
                p_list.append(
                    ctx.Process(
                        target=partial(
                            FileSystemWriterAsync.write_preloaded_data, transform_list, open_file
                        ),
                        kwargs=kwargs,
                    )
                )
            except Exception as e:
                err_msg = f'An error was caught while proc {i} was created, error: {e}'
                logger.error(err_msg)
                write_results_or_exc = RuntimeError(err_msg)

        if not isinstance(write_results_or_exc, Exception):
            for p in p_list:
                p.start()

            logger.debug('FileSystemWriterAsync: collecting worker results...')

            # To make sure all workers have completed
            count_queue.join()
            # At this point, all workers completed, so the queue should have exactly
            # `len(write_buckets)` items
            for proc_idx in range(len(write_buckets)):
                try:
                    local_proc_idx, local_results_or_exc = local_results_queue.get()
                except queue.Empty:
                    write_results_or_exc = RuntimeError(
                        'Unexpected empty `local_results_queue`'
                        f' (got only {proc_idx}/{len(write_buckets)} items)'
                    )
                    break
                else:
                    if isinstance(local_results_or_exc, Exception):
                        err_msg = (
                            f"Local process {local_proc_idx} encountered"
                            f" an error: {local_results_or_exc}"
                        )
                        logger.error(err_msg)
                        write_results_or_exc = local_results_or_exc
                        break
                    assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
                    write_results_or_exc[local_proc_idx] = local_results_or_exc
                    p_list[local_proc_idx].join()

            logger.debug('FileSystemWriterAsync: collected worker results successfully')

        FileSystemWriterAsync._finalize_write_execution(
            global_results_queue, write_results_or_exc, rank, w_start, "MultiProc", logger
        )

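    # Sketch of the two-queue completion protocol used above (an interpretation of
    # the code, not part of the original source): a token `i` is put on
    # `count_queue` for every spawned worker; each worker pushes its
    # (idx, results-or-exception) tuple to `local_results_queue` *before* consuming
    # a token and calling `task_done()`. Hence `count_queue.join()` returns only
    # after every worker has enqueued a result, and failures travel through the
    # same results channel as successful writes instead of being lost with the
    # dead process.
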
    @staticmethod
    def _write_bucket_to_storage(
        transform_list: List[_StorageWriterTransforms],
        open_file: Callable,
        write_bucket: WriteBucket,
        use_fsync: bool,
        use_msc: bool,
    ) -> List[WriteResult]:
        """
        Core logic for writing a bucket to storage.

        Args:
            transform_list (List[_StorageWriterTransforms]): streaming transforms list
            open_file (Callable): file open callable
            write_bucket (WriteBucket): data to write to storage
            use_fsync (bool): if True, calls os.fsync at the end of saving
            use_msc (bool): flag to indicate use of multi storage client

        Returns:
            List[WriteResult]: list of write results
        """
        file_name, storage_key, (bytes_data, tensor_data) = write_bucket
        extra_kwargs = {}
        write_fn = _write_item
        if "serialization_format" in inspect.signature(_write_item).parameters:
            from torch.distributed.checkpoint.filesystem import SerializationFormat

            extra_kwargs['serialization_format'] = SerializationFormat.TORCH_SAVE
        if "transforms" in inspect.signature(_write_item).parameters:
            assert len(transform_list) <= 1
            write_fn = partial(_write_item, *transform_list)

        local_results = []
        with open_file(file_name, "wb") as stream:
            for write_item, data in bytes_data:
                local_results.append(
                    write_fn(stream, data, write_item, storage_key, **extra_kwargs)
                )

            for write_item, tensor in tensor_data:
                assert tensor.is_cpu
                local_results.append(
                    write_fn(stream, tensor, write_item, storage_key, **extra_kwargs)
                )

            if use_fsync:
                if use_msc:
                    stream.fsync()
                else:
                    os.fsync(stream.fileno())
        return local_results

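    # The `inspect.signature(_write_item)` checks above feature-detect the
    # installed PyTorch version: the `serialization_format` and `transforms`
    # parameters exist only in newer releases, so the extra kwargs are passed
    # through solely when `_write_item` actually accepts them.
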
    @staticmethod
    @_disable_gc()
    def write_preloaded_data(
        transform_list: List[_StorageWriterTransforms],
        open_file: Callable,
        local_proc_idx: int,
        write_bucket: WriteBucket,
        results_queue: mp.SimpleQueue,
        count_queue: mp.JoinableQueue,
        use_fsync: bool,
        **kwargs,
    ) -> None:
        """
        Performs the actual data saving to storage.

        Args:
            local_proc_idx (int): index of the local process that performs the writing
            write_bucket (WriteBucket): data to write to storage
            results_queue (mp.Queue): queue to return the write results to the proxy
                checkpoint process.
            count_queue (mp.JoinableQueue): queue to mark the worker task as completed
            use_fsync (bool): if True, calls os.fsync at the end of saving

        Returns: None, the write results are put into `results_queue`
        """
        logger = logging.getLogger(__name__)
        logger.debug(f'{local_proc_idx} started')
        mem_before = _process_memory()
        use_msc = kwargs.get('use_msc', False)

        try:
            local_results = FileSystemWriterAsync._write_bucket_to_storage(
                transform_list, open_file, write_bucket, use_fsync, use_msc
            )
            local_output = (local_proc_idx, local_results)
        except Exception as e:
            logger.debug(f'{local_proc_idx} failed')
            local_output = (local_proc_idx, e)

        results_queue.put(local_output)
        # Signal that this process is done.
        count_queue.get()
        count_queue.task_done()

        mem_after = _process_memory()
        logger.debug(
            f"{local_proc_idx} consumed: {mem_after - mem_before},"
            f" before: {mem_before}, after: {mem_after}"
        )

    @staticmethod
    def write_preloaded_data_multithread_launcher(
        transform_list: List[_StorageWriterTransforms],
        use_msc: bool,
        open_file: Callable,
        rank: int,
        write_buckets: List[WriteBucket],
        global_results_queue: mp.Queue,
    ) -> None:
        """
        Performs saving data to storage with multiple threads.

        Args:
            transform_list (List[_StorageWriterTransforms]): streaming transforms list
            use_msc (bool): flag to indicate use of multi storage client for storage access
            open_file (Callable): file open callable
            rank (int): training rank
            write_buckets (List[WriteBucket]): write plan
            global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]]
                (or an Exception) from parallel write threads to the main training process

        Returns: None
        """
        logger, w_start, write_results_or_exc = FileSystemWriterAsync._initialize_write_execution(
            rank
        )

        # Use ThreadPoolExecutor for efficient thread management
        with ThreadPoolExecutor(max_workers=len(write_buckets)) as executor:
            # Submit write requests to the thread executor
            future_to_bucket = {}
            for i, write_bucket in enumerate(write_buckets):
                try:
                    kwargs = FileSystemWriterAsync._build_worker_kwargs(
                        worker_idx=i,
                        write_bucket=write_bucket,
                        use_msc=use_msc,
                        worker_type='thread',
                    )
                    # Submit the task to the thread pool
                    future = executor.submit(
                        FileSystemWriterAsync.write_preloaded_data_thread,
                        transform_list,
                        open_file,
                        **kwargs,
                    )
                    future_to_bucket[future] = i
                except Exception as e:
                    err_msg = f"An error was caught while thread {i} was created, error: {e}"
                    logger.error(err_msg)
                    write_results_or_exc = RuntimeError(err_msg)

            if not isinstance(write_results_or_exc, Exception):
                logger.debug('FileSystemWriterAsync: collecting worker results...')
                try:
                    for future in as_completed(future_to_bucket):
                        bucket_idx = future_to_bucket[future]
                        local_results = future.result()
                        if not local_results:
                            # The write results list is empty. This is unexpected behavior,
                            # as we expect every thread to have some write work.
                            err_msg = (
                                "Unexpected empty `local_results`:"
                                f" thread-id {bucket_idx} among {len(write_buckets)}"
                                " did not have any items to write."
                                " Check the split write buckets logic."
                            )
                            logger.error(err_msg)
                            write_results_or_exc = RuntimeError(err_msg)
                            break
                        if isinstance(local_results, Exception):
                            err_msg = (
                                f"Thread-ID {bucket_idx} encountered an error: {local_results}"
                            )
                            logger.error(err_msg)
                            write_results_or_exc = local_results
                            break
                        assert isinstance(local_results, list), type(local_results)
                        write_results_or_exc[bucket_idx] = local_results
                except Exception as e:
                    err_msg = f"During async write, encountered an error: {e}"
                    logger.error(err_msg)
                    write_results_or_exc = e
                finally:
                    # Cancel any futures that are still running.
                    # In case of errors or exceptions, we may have running futures
                    # due to an early break.
                    for f in future_to_bucket:
                        if not f.done():
                            f.cancel()
                    # Shutdown the thread pool executor
                    executor.shutdown(cancel_futures=True)

                logger.debug("FileSystemWriterAsync: collected worker results successfully")

        FileSystemWriterAsync._finalize_write_execution(
            global_results_queue, write_results_or_exc, rank, w_start, "MultiThread", logger
        )

    @staticmethod
    def write_preloaded_data_thread(
        transform_list: List[_StorageWriterTransforms],
        open_file: Callable,
        local_thread_idx: int,
        write_bucket: WriteBucket,
        use_fsync: bool,
        **kwargs,
    ) -> Union[List[WriteResult], Exception]:
        """
        Performs the actual data saving to storage.

        Args:
            transform_list (List[_StorageWriterTransforms]): streaming transforms list
            open_file (Callable): file open callable
            local_thread_idx (int): index of the local thread that performs the writing
            write_bucket (WriteBucket): data to write to storage
            use_fsync (bool): if True, calls os.fsync at the end of saving

        Returns:
            Union[List[WriteResult], Exception]: the write results, or the exception
            raised during writing
        """
        logger = logging.getLogger(__name__)
        logger.debug(f'{local_thread_idx} started')
        mem_before = _process_memory()
        use_msc = kwargs.get('use_msc', False)

        try:
            local_results = FileSystemWriterAsync._write_bucket_to_storage(
                transform_list, open_file, write_bucket, use_fsync, use_msc
            )
        except Exception as e:
            logger.debug(f'{local_thread_idx} failed with exception {e}')
            local_results = e

        mem_after = _process_memory()
        logger.debug(
            f"{local_thread_idx} consumed: {mem_after - mem_before},"
            f" before: {mem_before}, after: {mem_after}"
        )
        return local_results

    def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
        """Write all items from ``plan``."""
        raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')

    def retrieve_write_results(self) -> Union[List[WriteResult], WRAPPED_EXCEPTION]:
        """
        Turn the latest dict of write results from `self.results_queue` into a single
        results list. Includes an error check.

        Returns (Union[List[WriteResult], WRAPPED_EXCEPTION]): the list of write results
            from all local processes performing the save, or a WRAPPED_EXCEPTION if an
            exception was raised during the writing process.
        """
        assert self.write_buckets is not None

        if self.results_queue is None:
            write_results_or_exc = {}
        else:
            try:
                write_results_or_exc = self.results_queue.get_nowait()
            except queue.Empty:
                return _wrap_exception(RuntimeError('results_queue should not be empty'))

        if isinstance(write_results_or_exc, Exception):
            try:
                raise RuntimeError(
                    f'Worker failure: {write_results_or_exc}'
                ) from write_results_or_exc
            except Exception as e:
                return _wrap_exception(e)
        write_results: dict = write_results_or_exc
        if len(write_results) != len(self.write_buckets):
            return _wrap_exception(
                RuntimeError(
                    f'Incomplete worker results (expected {len(self.write_buckets)},'
                    f' got {len(write_results)}). This probably indicates a worker failure.'
                )
            )
        return list(chain.from_iterable(write_results.values()))

    def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
        """Instead of assigning indices by plan order, uses the PyT rank (same outcome).

        Args:
            local_plan (SavePlan): local plan to turn into a global plan
                (without interactions with other ranks)

        Returns:
            SavePlan - locally transformed plan equivalent to the plan that would be
            created by the coordinator
        """
        return dataclasses.replace(
            local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
        )

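    # Example (illustrative): on rank 3 the prefix becomes `__3_`, so `gen_file`
    # in `prepare_write_data` produces files named `__3_0.distcp`, `__3_1.distcp`,
    # ... - the same shape of names a coordinator-assigned `_StoragePrefix` would
    # yield, just keyed by rank instead of plan order.
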
    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """
        Finish the checkpointing process.

        Args:
            metadata (Metadata): metadata to save
            results (List[List[WriteResult]]): results to save
        """
        if self.use_msc:
            import multistorageclient as msc

            storage_md = dict()
            for wr_list in results:
                storage_md.update({wr.index: wr.storage_data for wr in wr_list})
            metadata.storage_data = storage_md

            # storage_meta was introduced in PyTorch 2.4
            if "storage_meta" in inspect.signature(Metadata).parameters:
                metadata.storage_meta = self.storage_meta()

            path = os.path.join(self.checkpoint_dir, ".metadata")
            with msc.open(path, "wb") as metadata_file:
                # Issue: [B301:blacklist] Pickle and modules that wrap it can be unsafe
                # when used to deserialize untrusted data, possible security issue.
                # Severity: Medium   Confidence: High
                # CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
                # More Info: https://bandit.readthedocs.io/en/1.8.3/blacklists/blacklist_calls.html#b301-pickle
                pickle.dump(metadata, metadata_file)  # nosec
        else:
            super().finish(metadata, results)

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Prepare the local plan for the checkpointing process.
        """
        if self.use_msc:
            import multistorageclient as msc

            msc.os.makedirs(str(self.checkpoint_dir), exist_ok=True)
        else:
            super().prepare_local_plan(plan)
        return plan

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        Return the checkpoint_id that will be used to save the checkpoint.
        """
        return str(self.checkpoint_dir)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return True

def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    """
    Splits write items according to item size into close-to-uniform bins.

    Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
    but with a fixed _item_size function.

    Args:
        bins (int): number of bins to split into
        items (List[WriteItem]): list of write items

    Returns (List[List[WriteItem]]): write items split into bins
    """
    if bins == 1:
        return [items]

    bytes_items: List[WriteItem] = []
    tensor_items: List[WriteItem] = []
    for wi in items:
        container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items
        container.append(wi)

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]

    # Assign bytes with a simple round-robin
    for i, item in enumerate(bytes_items):
        buckets[i % bins].append(item)

    # Sort tensor items by size in decreasing order once and store the size with the item
    sized_tensors = [(item, _item_size(item)) for item in tensor_items]
    sized_tensors.sort(key=itemgetter(1), reverse=True)

    # Use a min-heap for bin assignment.
    # Store (total_size_of_bin, bin_index) tuples.
    heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)]

    # Assign each tensor to the currently smallest bin
    for item, size in sized_tensors:
        total_bin_size, bin_idx = heappop(heap)
        buckets[bin_idx].append(item)
        heappush(heap, (total_bin_size + size, bin_idx))

    return buckets


def _split_by_separation_hint(
    buckets: List[List[WriteItem]], separation_hint: Optional[str] = None
) -> Dict[str, List[List[WriteItem]]]:
    """
    Splits buckets into those whose keys begin with the separation_hint
    and those whose keys do not.

    Args:
        buckets (List[List[WriteItem]]): buckets to split
        separation_hint (Optional[str]): optional prefix to split on

    Returns (Dict[str, List[List[WriteItem]]]): a dictionary mapping each prefix
        to the relevant buckets
    """
    bins = len(buckets)
    buckets_with_separation_hint = {}
    if separation_hint is not None:
        buckets_default = [[] for _ in range(bins)]
        buckets_hint = [[] for _ in range(bins)]
        for i in range(bins):
            for item in buckets[i]:
                if item.index.fqn.startswith(separation_hint):
                    buckets_hint[i].append(item)
                else:
                    buckets_default[i].append(item)
        buckets_with_separation_hint[""] = buckets_default
        buckets_with_separation_hint[separation_hint] = buckets_hint
    else:
        buckets_with_separation_hint[""] = buckets
    return buckets_with_separation_hint


def _item_size(item: WriteItem) -> int:
    """
    Calculates the size (in bytes) of a single write item.

    Same as torch.distributed.checkpoint.filesystem._item_size, but fixes
    computing the chunk size (with item.tensor_data.chunk.sizes).

    Args:
        item (WriteItem): write item to compute the size of

    Returns (int): size of the item in bytes
    """
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older Python versions
    for s in item.tensor_data.chunk.sizes:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)


def _process_memory() -> int:
    """
    Get the memory used by the current process.

    Returns (int): memory used by the current process (RSS, in bytes)
    """
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss
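
# Worked example of the greedy min-heap packing in `_split_by_size_and_type`
# (illustrative, not part of the original source). With tensor sizes
# [9, 7, 5, 3, 2] and bins=2, items go largest-first into the currently
# smallest bin:
#
#     9 -> bin 0 (total 9)   7 -> bin 1 (7)   5 -> bin 1 (12)
#     3 -> bin 0 (12)        2 -> bin 0 (tie at 12, broken by bin index)
#
# Final totals: 14 vs 12, i.e. close-to-uniform buckets, which balances
# per-worker write time.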