Asynchronous PyTorch Distributed Checkpoint save with optimized FileSystemWriter
State dict saver for the PyTorch Distributed checkpoint format, allowing asynchronous save.
- class nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.CheckpointMetadataCache[source]
Bases:
object
Cache of metadata for checkpoint saving.
This class maintains a cache of metadata used during distributed checkpoint saving operations. It stores various components of the save plan and metadata to optimize subsequent checkpoint saves by avoiding redundant planning and metadata generation when the checkpoint structure remains consistent across iterations.
The cache stores:
- cached_central_plan: The aggregated global save plan from all ranks
- cached_local_plan: The local save plan describing how the local state_dict is written
- cached_global_metadata: The global metadata (held only by the coordinator rank)
- validated_cache_reuse: Flag indicating whether checkpoint structures are consistent across iterations
- validated_loaded_metadata_reuse: Flag indicating that the metadata loaded from the previous checkpoint has been validated for reuse, which skips all metadata communication
- loaded_all_plans: Cached local plans from the previous checkpoint’s metadata file
This caching mechanism helps optimize checkpoint saving by:
1. Avoiding redundant planning when checkpoint structures are consistent
2. Reusing global metadata when possible
3. Enabling decentralized planning when supported by the planner and storage writer
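For illustration, here is a minimal sketch of threading one cache through repeated saves; the no-argument CheckpointMetadataCache constructor and the state_dict/writer setup are placeholder assumptions, not a confirmed recipe:

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    CheckpointMetadataCache,
    save_state_dict_async_plan,
)

# Seed the cache with global metadata from a previous checkpoint (assumed loaded).
metadata_cache = CheckpointMetadataCache()  # assumed no-arg constructor
metadata_cache.set_cached_global_metadata(prev_global_metadata)

# With a consistent checkpoint structure across iterations, planning and
# metadata generation are reused instead of recomputed on each save.
for _ in range(3):  # placeholder training loop
    writer, metadata, dist_wrapper = save_state_dict_async_plan(
        state_dict,   # placeholder: model/optimizer state dict
        writer,       # placeholder: a FileSystemWriterAsync instance
        enable_cache=True,
        metadata_cache=metadata_cache,
    )
    # ... run the async write stage, then save_state_dict_async_finalize(...)
```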
- get_cache_metadata()[source]
Retrieves the cached metadata components.
This method returns a tuple containing the cached central plan, local plan, cache reuse validation, and all local plans from the previous checkpoint’s metadata file.
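The tuple can be unpacked in the documented order, as in this sketch (metadata_cache is a placeholder CheckpointMetadataCache instance):

```python
# Unpack in the documented order: central plan, local plan,
# cache-reuse flag, and the previous checkpoint's local plans.
(
    cached_central_plan,    # SavePlan aggregated across all ranks
    cached_local_plan,      # SavePlan for this rank's state_dict
    validated_cache_reuse,  # bool: structure consistent across iterations
    loaded_all_plans,       # local plans from the previous checkpoint's metadata
) = metadata_cache.get_cache_metadata()
```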
- get_metadata_caching_status()[source]
Retrieves the current caching status.
This method returns the current caching status of the checkpoint metadata.
- prepare_save_state_dict_ret(rank, coordinator, save_state_dict_ret)[source]
Prepares the save state dict return value based on the cached metadata.
This method checks if the global metadata can be reused from the previous checkpoint. If so, it updates the save state dict return value with the cached global metadata.
- Parameters:
rank (int) – The rank of the current process
coordinator (int) – The coordinator rank
save_state_dict_ret (Tuple[FileSystemWriterAsync, Union[Metadata, None]]) – The return value of the save state dict
- Returns:
The updated save state dict return value with the cached global metadata if it can be reused.
- Return type:
Tuple[FileSystemWriterAsync, Union[Metadata, None]]
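A hedged sketch of the call; save_state_dict_ret is a placeholder produced by an earlier planning step:

```python
import torch.distributed as dist

# If the cached global metadata is validated for reuse, the
# (writer, metadata) pair is updated with the cached metadata.
save_state_dict_ret = metadata_cache.prepare_save_state_dict_ret(
    rank=dist.get_rank(),
    coordinator=0,  # coordinator rank used for planning
    save_state_dict_ret=save_state_dict_ret,
)
```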
- set_cache_metadata(central_plan, local_plan, global_md_verify_reuse)[source]
Sets the cached metadata and updates the cache flags.
This method updates the cache with the latest central plan, local plan, and metadata reuse validation results. It also checks if the central plan is consistent with the cached plan.
- Parameters:
central_plan (SavePlan) – The latest central plan
local_plan (SavePlan) – The latest local plan
global_md_verify_reuse (bool) – Flag indicating if global metadata reuse is valid
- set_cached_global_metadata(cached_global_metadata)[source]
Sets the cached global metadata and extracts local plans from it.
This method stores the global metadata from a previous checkpoint and attempts to extract the local plans from it. The local plans are used to verify if the global metadata can be reused in subsequent checkpoint saves.
- Parameters:
cached_global_metadata (Metadata) – The global metadata from a previous checkpoint that contains information about the checkpoint structure and local plans.
Note
If the metadata does not contain local plans, a debug message is logged indicating that global metadata reuse verification will not be possible.
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.get_metadata_caching_status()[source]
Retrieves the current caching status.
This function returns the current caching status of the checkpoint metadata.
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.init_checkpoint_metadata_cache(cached_global_metadata)[source]
Initializes the checkpoint metadata cache.
This function creates a new CheckpointMetadataCache instance and sets the cached global metadata from the previous checkpoint.
- Parameters:
cached_global_metadata (Metadata)
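For example, when resuming training, metadata loaded from the previous checkpoint can seed the cache; the load_previous_metadata helper below is hypothetical:

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
)

# Hypothetical loader returning the Metadata of the last checkpoint.
prev_metadata = load_previous_metadata()
init_checkpoint_metadata_cache(prev_metadata)
```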
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)[source]
Finalization of save_state_dict_async_plan.
The input arguments are the same as the save_state_dict_async_plan output; the write_results are retrieved from the storage_writer.
- Parameters:
storage_writer (FileSystemWriterAsync) – storage writer used for planning
global_metadata (Metadata) – metadata created during planning
dist_wrapper (_DistWrapper) – distributed wrapper created during planning
- Return type:
None
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_plan(state_dict, storage_writer, process_group=None, coordinator_rank=0, planner=None, enable_cache=False, metadata_cache=None)[source]
First stage of saving a state dict to storage.
This is an async adjustment of torch.distributed.checkpoint.state_dict_saver. In order to support async save, saving should be split into three parts:
1. Planning
2. Actual saving
3. Finalization
Out of these, step (2) must happen asynchronously. The first step is realized with this function.
The planning part consists of several steps, described here: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner
- Parameters:
state_dict (STATE_DICT_TYPE) – state dict to save
storage_writer (FileSystemWriterAsync) – in current version only an instance of FileSystemWriterAsync
process_group (dist.ProcessGroup, optional) – process group used for save planning
coordinator_rank (int, optional) – coordinator rank for planning. Defaults to 0.
planner (SavePlanner, optional) – save planner for torch.distributed.checkpoint format
cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], optional) – Each element of this tuple is used in the following order:
- cached_central_plan (SavePlan): a globally coordinated save plan cached in the previous iteration
- cached_local_plan (SavePlan): a local plan cached in the previous iteration
- validated_cache_reuse (bool): whether global_metadata and the planning dict are consistent over iterations
enable_cache (bool) – whether to cache plans and metadata across checkpoint saves
metadata_cache (CheckpointMetadataCache | None) – the metadata cache instance to use when caching is enabled
- Return type:
Tuple[FileSystemWriterAsync, Metadata | None, _DistWrapper]
- Returns: Tuple of:
storage writer (the one passed as input)
metadata from planning (or None if we reuse cached global metadata)
distributed wrapper used for planning
The return value of this function should be passed as an input to save_state_dict_async_finalize, and as the cached plan to skip reduce_scatter at planning.
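A minimal sketch of the full three-stage flow; the FileSystemWriterAsync import path and constructor arguments, and the run_write_stage_in_background helper for stage 2, are assumptions for illustration:

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import (
    FileSystemWriterAsync,  # assumed import path
)
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

state_dict = model.state_dict()              # placeholder: model assumed initialized
writer = FileSystemWriterAsync("/tmp/ckpt")  # assumed constructor signature

# Stage 1: planning (synchronous, involves collective communication)
writer, global_metadata, dist_wrapper = save_state_dict_async_plan(
    state_dict, writer, process_group=None, coordinator_rank=0
)

# Stage 2: actual saving -- must happen asynchronously (e.g. in a background
# process) so training can continue; hypothetical helper shown here
run_write_stage_in_background(writer)

# Stage 3: finalization, once the asynchronous writes have completed
save_state_dict_async_finalize(writer, global_metadata, dist_wrapper)
```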
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.verify_global_md_reuse(loaded_all_plans, local_plan, rank, dist_wrapper)[source]
Verifies that global metadata reuse is possible by checking that the plans loaded from the previous checkpoint are consistent, which means the same settings are in use when resuming training.
- Parameters:
loaded_all_plans – cached local plans from the previous checkpoint’s metadata file
local_plan (SavePlan) – the local save plan of the current process
rank (int) – the rank of the current process
dist_wrapper (_DistWrapper) – distributed wrapper used for planning
- Return type:
bool
Returns: True iff global metadata reuse is possible.
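A hedged sketch of where this check might slot into planning; all inputs are placeholders mirroring the signature:

```python
can_reuse = verify_global_md_reuse(
    loaded_all_plans,  # local plans from the previous checkpoint's metadata
    local_plan,        # this rank's current SavePlan
    rank,              # this process's rank
    dist_wrapper,      # _DistWrapper used for the collective consistency check
)
if can_reuse:
    # Reuse the cached global metadata and skip metadata communication.
    ...
```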