Asynchronous PyTorch Distributed Checkpoint save with optimized FileSystemWriter

State dict saver for the PyTorch Distributed Checkpoint format, allowing asynchronous save.

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)[source]

Finalization step of the asynchronous save started by save_state_dict_async_plan.

The input arguments are the same as the output of save_state_dict_async_plan; the write_results are retrieved from the storage_writer (see the usage sketch below).

Parameters:
  • storage_writer (FileSystemWriterAsync) – The storage writer used for planning.

  • global_metadata (Metadata) – The metadata created during planning.

  • dist_wrapper (_DistWrapper) – The distributed wrapper created during planning.

Return type:

None

Returns: None
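
A minimal usage sketch, assuming the three arguments were obtained from save_state_dict_async_plan (documented below) and that the asynchronous writes have already completed; the wrapper function name is illustrative:

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
)

def finalize_async_save(storage_writer, global_metadata, dist_wrapper) -> None:
    # The write results are read back from storage_writer internally, so only
    # the three planning outputs need to be passed in.
    save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)
```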

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_plan(state_dict, storage_writer, process_group=None, coordinator_rank=0, planner=None, cached_ckpt_structure=None, loaded_all_plans=None)[source]

First stage of asynchronously saving a state dict to storage.

This is an async adaptation of torch.distributed.checkpoint.state_dict_saver. To support async save, the process is split into three stages:

  1. Planning

  2. Actual saving (must be asynchronous)

  3. Finalization

The planning step is handled by this function and follows several steps as described in the [PyTorch documentation](https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner).
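
For orientation, below is a minimal sketch of the three stages wired together. `run_async_writes` is a hypothetical stand-in for however the asynchronous write stage is dispatched (e.g. in a separate process), and the unpacking of the planning result follows the Returns section below; check the source for the exact tuple layout.

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

def async_save(state_dict, writer, run_async_writes):
    # `writer` is assumed to be a FileSystemWriterAsync instance (see Parameters).

    # Stage 1: planning (this function).
    planning_output = save_state_dict_async_plan(state_dict, writer)
    # Assumption: the first three components are those documented under Returns;
    # any additional cached-plan entries are ignored in this sketch.
    storage_writer, global_metadata, dist_wrapper = planning_output[:3]

    # Stage 2: the actual writes must run asynchronously. `run_async_writes` is a
    # hypothetical callback that dispatches them via the storage writer and waits
    # for completion.
    run_async_writes(storage_writer)

    # Stage 3: finalization retrieves the write results from the storage writer
    # and commits the checkpoint metadata.
    save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)
```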

Parameters:
  • state_dict (STATE_DICT_TYPE) – The state dict to save.

  • storage_writer (FileSystemWriterAsync) – The storage writer. Currently, only an instance of FileSystemWriterAsync is supported.

  • process_group (dist.ProcessGroup, optional) – The process group used for save planning.

  • coordinator_rank (int, optional) – The coordinator rank for planning. Defaults to 0.

  • planner (SavePlanner, optional) – The save planner for the torch.distributed.checkpoint format.

  • cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], optional) – A tuple containing:
      - cached_central_plan (SavePlan): A globally coordinated save plan cached from the previous iteration.
      - cached_local_plan (SavePlan): A local save plan cached from the previous iteration.
      - validated_cache_reuse (bool): Whether the global metadata and planning dict are consistent across iterations.

  • loaded_all_plans (List[SavePlan] | None) – Save plans loaded from the checkpoint metadata of the run being resumed, used to check whether cached global metadata can be reused (see verify_global_md_reuse). Defaults to None.

Returns:

  • The storage writer (same as input).

  • Metadata from planning (or None if cached global metadata is reused).

  • The distributed wrapper used for planning.

Return type:

Tuple

The return value of this function should be passed as input to save_state_dict_async_finalize, along with the cached plan, to skip the reduce_scatter step during planning.
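
A sketch of how the cached plan might be fed back across training iterations to skip the reduce_scatter step. `make_state_dict`, `make_writer`, `run_async_writes`, and `extract_cached_plans` are hypothetical callables; how the cached plans are surfaced from the planning output is implementation-specific and should be checked against the source.

```python
from typing import Optional, Tuple

from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

def save_loop_with_cached_plans(
    make_state_dict,       # hypothetical: builds the state dict for the current step
    make_writer,           # hypothetical: returns a fresh FileSystemWriterAsync
    run_async_writes,      # hypothetical: dispatches and waits for the async writes
    extract_cached_plans,  # hypothetical: pulls (central_plan, local_plan, reuse_flag) from the planning output
    num_steps: int,
):
    cached_ckpt_structure: Optional[Tuple] = None  # nothing cached on the first step

    for _ in range(num_steps):
        writer = make_writer()
        planning_output = save_state_dict_async_plan(
            make_state_dict(),
            writer,
            cached_ckpt_structure=cached_ckpt_structure,
        )
        # Documented components needed by finalization (assumed tuple layout).
        storage_writer, global_metadata, dist_wrapper = planning_output[:3]

        run_async_writes(storage_writer)
        save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)

        # Reuse the coordinated plans next time so planning can skip reduce_scatter
        # when the checkpoint structure has not changed.
        cached_ckpt_structure = extract_cached_plans(planning_output)
```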

nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.verify_global_md_reuse(loaded_all_plans, local_plan, rank, dist_wrapper)[source]

Verifies that global metadata reuse is possible by checking that the plans loaded from the checkpoint are consistent, i.e., that the same settings are in effect when resuming training.

Parameters:
  • loaded_all_plans (List[SavePlan]) – The plans loaded from the checkpoint (stored in the checkpoint metadata).

  • local_plan (SavePlan) – The local save plan.

  • rank (int) – Current process rank.

  • dist_wrapper (_DistWrapper) – The distributed wrapper created during planning.

Return type:

bool

Returns: True iff global metadata reuse is possible.
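
A small usage sketch; `loaded_all_plans` is assumed to come from the checkpoint metadata of the run being resumed, and `local_plan`, `rank`, and `dist_wrapper` from the current planning step (the wrapper function is illustrative):

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    verify_global_md_reuse,
)

def can_reuse_global_metadata(loaded_all_plans, local_plan, rank, dist_wrapper) -> bool:
    if not loaded_all_plans:
        # Nothing stored in the checkpoint metadata to compare against.
        return False
    # True iff the stored plans are consistent with the current run, i.e. the
    # same settings are in effect when resuming training.
    return verify_global_md_reuse(loaded_all_plans, local_plan, rank, dist_wrapper)
```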