Asynchronous PyTorch Distributed Checkpoint save with optimized FileSystemWriter

State dict saver for PyT Distributed format allowing asynchronous save.

class nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.CheckpointMetadataCache[source]

Bases: object

Cache of metadata for checkpoint saving.

This class maintains a cache of metadata used during distributed checkpoint saving operations. It stores various components of the save plan and metadata to optimize subsequent checkpoint saves by avoiding redundant planning and metadata generation when the checkpoint structure remains consistent across iterations.

The cache stores:
  • cached_central_plan: The aggregated global save plan from all ranks

  • cached_local_plan: The local save plan describing how the local state_dict is written

  • cached_global_metadata: The global metadata (only held by the coordinator rank)

  • validated_cache_reuse: Flag indicating if checkpoint structures are consistent

  • validated_loaded_metadata_reuse: Flag indicating that the metadata loaded from the previous checkpoint has been validated for reuse, which skips all metadata communications

  • loaded_all_plans: Cached local plans from the previous checkpoint’s metadata file

This caching mechanism helps optimize checkpoint saving by:
  1. Avoiding redundant planning when checkpoint structures are consistent

  2. Reusing global metadata when possible

  3. Enabling decentralized planning when supported by the planner and storage writer
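
For illustration, a minimal sketch of how the cache is typically wired into repeated checkpoint saves; the training loop, the state_dict and writer construction, and the no-argument CheckpointMetadataCache() constructor are assumptions made for this example:

    from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
        CheckpointMetadataCache,
        save_state_dict_async_plan,
    )

    metadata_cache = CheckpointMetadataCache()  # assumed no-argument constructor

    for step in range(num_steps):  # hypothetical training loop
        state_dict = build_sharded_state_dict()  # placeholder for the real state dict
        writer = build_writer(step)              # placeholder FileSystemWriterAsync
        # With enable_cache=True, planning and metadata generation can be skipped
        # when the checkpoint structure matches the plans cached on the previous save.
        writer, metadata, dist_wrapper = save_state_dict_async_plan(
            state_dict,
            writer,
            enable_cache=True,
            metadata_cache=metadata_cache,
        )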

get_cache_metadata()[source]

Retrieves the cached metadata components.

This method returns a tuple containing the cached central plan, local plan, cache reuse validation, and all local plans from the previous checkpoint’s metadata file.

Return type:

Tuple[SavePlan, SavePlan, bool, List[SavePlan]] | None
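
For illustration, the cached components can be unpacked as follows (assuming the cache was populated by a previous save):

    cached = metadata_cache.get_cache_metadata()
    if cached is not None:
        central_plan, local_plan, validated_cache_reuse, loaded_all_plans = cached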

get_metadata_caching_status()[source]

Retrieves the current caching status.

This method returns the current caching status of the checkpoint metadata.

prepare_save_state_dict_ret(rank, coordinator, save_state_dict_ret)[source]

Prepares the save state dict return value based on the cached metadata.

This method checks if the global metadata can be reused from the previous checkpoint. If so, it updates the save state dict return value with the cached global metadata.

Parameters:
  • rank (int) – The rank of the current process

  • coordinator (int) – The coordinator rank

  • save_state_dict_ret (Tuple[FileSystemWriterAsync, Union[Metadata, None]]) – The return value of the save state dict

Returns:

The updated save state dict return value with the cached global metadata if it can be reused.

Return type:

Tuple[FileSystemWriterAsync, Union[Metadata, None]]
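
A sketch of the intended call pattern, assuming save_state_dict_ret is the (storage writer, metadata) tuple produced during planning:

    save_state_dict_ret = (storage_writer, global_metadata)
    save_state_dict_ret = metadata_cache.prepare_save_state_dict_ret(
        rank, coordinator, save_state_dict_ret
    )
    # If the previously loaded global metadata was validated for reuse, the
    # metadata element of the tuple now holds the cached Metadata object.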

set_cache_metadata(central_plan, local_plan, global_md_verify_reuse)[source]

Sets the cached metadata and updates the cache flags.

This method updates the cache with the latest central plan, local plan, and metadata reuse validation results. It also checks if the central plan is consistent with the cached plan.

Parameters:
  • central_plan (SavePlan) – The latest central plan

  • local_plan (SavePlan) – The latest local plan

  • global_md_verify_reuse (bool) – Flag indicating if global metadata reuse is valid

set_cached_global_metadata(cached_global_metadata)[source]

Sets the cached global metadata and extracts local plans from it.

This method stores the global metadata from a previous checkpoint and attempts to extract the local plans from it. The local plans are used to verify if the global metadata can be reused in subsequent checkpoint saves.

Parameters:

cached_global_metadata (Metadata) – The global metadata from a previous checkpoint that contains information about the checkpoint structure and local plans.

Note

If the metadata does not contain local plans, a debug message is logged indicating that global metadata reuse verification will not be possible.

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.get_metadata_caching_status()[source]

Retrieves the current caching status.

This function returns the current caching status of the checkpoint metadata.

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.init_checkpoint_metadata_cache(cached_global_metadata)[source]

Initializes the checkpoint metadata cache.

This function creates a new CheckpointMetadataCache instance and sets the cached global metadata from the previous checkpoint

Parameters:

cached_global_metadata (Metadata)
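
For illustration, a sketch of seeding the cache from the previous run’s checkpoint; reading the Metadata back with torch.distributed.checkpoint.FileSystemReader is an assumption made for this example and is not part of this module:

    import torch.distributed.checkpoint as dcp

    from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
        init_checkpoint_metadata_cache,
    )

    prev_metadata = dcp.FileSystemReader("/path/to/prev_checkpoint").read_metadata()
    init_checkpoint_metadata_cache(prev_metadata)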

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)[source]

Finalization of save_state_dict_async_plan.

The input arguments are the same as the output of save_state_dict_async_plan; the write_results are retrieved from the storage_writer.

Parameters:
  • storage_writer (FileSystemWriterAsync) – storage writer used for planning

  • global_metadata (Metadata) – metadata created during planning

  • dist_wrapper (_DistWrapper) – distributed wrapper created during planning

Return type:

None

Returns: None
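
For illustration, finalization consumes the tuple returned by save_state_dict_async_plan once the asynchronous write step has completed:

    # storage_writer, global_metadata and dist_wrapper come from save_state_dict_async_plan.
    save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)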

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_plan(state_dict, storage_writer, process_group=None, coordinator_rank=0, planner=None, enable_cache=False, metadata_cache=None)[source]

First stage of saving a state dict to storage.

This is an async adaptation of torch.distributed.checkpoint.state_dict_saver. In order to support async save, saving should be split into three parts:
  1. Planning

  2. Actual saving

  3. Finalization

Out of these, step (2) must happen asynchronously. The first step is realized with this function.

The planning part consists of several steps, described here: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner

Parameters:
  • state_dict (STATE_DICT_TYPE) – state dict to save

  • storage_writer (FileSystemWriterAsync) – in the current version, only an instance of FileSystemWriterAsync is supported

  • process_group (dist.ProcessGroup, optional) – process group used for save planning

  • coordinator_rank (int, optional) – coordinator rank for planning. Defaults to 0.

  • planner (SavePlanner, optional) – save planner for torch.distributed.checkpoint format

  • cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], optional) –

    Each element of this tuple is used in the following order:

    cached_central_plan (SavePlan): a globally coordinated save plan cached in the previous iteration

    cached_local_plan (SavePlan): a local plan cached in the previous iteration

    validated_cache_reuse (bool): flag indicating whether global_metadata and the planning dict are consistent over iterations

  • enable_cache (bool) – whether to enable the checkpoint metadata caching mechanism described in CheckpointMetadataCache. Defaults to False.

  • metadata_cache (CheckpointMetadataCache | None) – the CheckpointMetadataCache instance to use when caching is enabled. Defaults to None.

Return type:

Tuple[FileSystemWriterAsync, Metadata | None, _DistWrapper]

Returns: Tuple of:
  • storage writer (the one passed as input)

  • metadata from planning (or None if we reuse cached global metadata)

  • distributed wrapper used for planning

The return value of this function should be passed as input to save_state_dict_async_finalize. With caching enabled, the cached plan allows skipping the reduce_scatter step during planning.
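
A minimal end-to-end sketch of the three stages; the FileSystemWriterAsync constructor arguments and the launch_async_write helper used for stage (2) are hypothetical, since only the planning and finalization entry points are documented here:

    from torch.distributed.checkpoint import DefaultSavePlanner

    writer = FileSystemWriterAsync(checkpoint_dir)  # assumed constructor signature

    # Stage 1: planning (collective, synchronous).
    writer, metadata, dist_wrapper = save_state_dict_async_plan(
        state_dict,
        writer,
        coordinator_rank=0,
        planner=DefaultSavePlanner(),
    )

    # Stage 2: the actual write, performed asynchronously by the storage writer
    # (hypothetical helper standing in for the writer's async mechanism).
    launch_async_write(writer)

    # Stage 3: finalization, after the asynchronous write has completed.
    save_state_dict_async_finalize(writer, metadata, dist_wrapper)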

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.verify_global_md_reuse(loaded_all_plans, local_plan, rank, dist_wrapper)[source]

Verifies that global metadata reuse is possible by checking that the plans loaded from the checkpoint are consistent, which means the same settings are used when resuming training.

Parameters:
  • loaded_all_plans (List[SavePlan]) – The loaded plans from the checkpoint (stored in checkpoint metadata).

  • local_plan (SavePlan) – The local save plan.

  • rank (int) – Current process rank.

  • dist_wrapper (_DistWrapper) – distributed wrapper created during planning

Return type:

bool

Returns: True iff global metadata reuse is possible.
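
As an illustrative sketch, assuming loaded_all_plans was extracted from the previous checkpoint’s metadata and local_plan is the plan produced for the current save:

    can_reuse = verify_global_md_reuse(loaded_all_plans, local_plan, rank, dist_wrapper)
    if can_reuse:
        # The global metadata loaded from the previous checkpoint can be reused,
        # skipping the metadata communication for this save.
        ...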