Asynchronous PyTorch Distributed Checkpoint save with optimized FileSystemWriter
State dict saver for the PyTorch Distributed checkpoint format, with support for asynchronous save.
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)[source]
Finalizes a save started with save_state_dict_async_plan.
The input arguments are the same as the outputs of save_state_dict_async_plan; the write results are retrieved from the storage_writer.
- Parameters:
storage_writer (FileSystemWriterAsync) – storage writer used for planning
global_metadata (Metadata) – metadata created during planning
dist_wrapper (_DistWrapper) – distributed wrapper created during planning
- Return type:
None
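A minimal sketch of the finalize stage is shown below. It assumes plan_result is the (storage_writer, global_metadata, dist_wrapper) tuple produced by save_state_dict_async_plan and that the asynchronous write has already completed, so the write results are available on the storage writer; the helper name is hypothetical.

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
)


def finalize_async_save(plan_result) -> None:
    # plan_result is assumed to be the (storage_writer, global_metadata, dist_wrapper)
    # tuple returned by save_state_dict_async_plan; the asynchronous write is assumed
    # to have completed, so the write results are available on the writer.
    storage_writer, global_metadata, dist_wrapper = plan_result
    # Retrieve the write results from the storage writer and write the global
    # metadata, completing the checkpoint.
    save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)
```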
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.save_state_dict_async_plan(state_dict, storage_writer, process_group=None, coordinator_rank=0, planner=None, cached_ckpt_structure=None, loaded_all_plans=None)[source]
First stage of asynchronously saving a state dict to storage.
This is an async adaptation of torch.distributed.checkpoint.state_dict_saver. To support async save, the process is split into three stages:
1. Planning
2. Actual saving (must be asynchronous)
3. Finalization
The planning step is handled by this function and follows several steps as described in the [PyTorch documentation](https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner).
- Parameters:
state_dict (STATE_DICT_TYPE) – The state dict to save.
storage_writer (FileSystemWriterAsync) – The storage writer. Currently, only an instance of FileSystemWriterAsync is supported.
process_group (dist.ProcessGroup, optional) – The process group used for save planning.
coordinator_rank (int, optional) – The coordinator rank for planning. Defaults to 0.
planner (SavePlanner, optional) – The save planner for the torch.distributed.checkpoint format.
cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], optional) – A tuple containing:
  - cached_central_plan (SavePlan): a globally coordinated save plan cached in the previous iteration
  - cached_local_plan (SavePlan): a local plan cached in the previous iteration
  - validated_cache_reuse (bool): whether global metadata and the planning dict are consistent over iterations
loaded_all_plans (optional) – plans loaded from the checkpoint being resumed, used to check whether the cached global metadata can be reused (see verify_global_md_reuse)
- Returns:
A tuple of:
  - the storage writer (same as the input),
  - metadata from planning (or None if cached global metadata is reused),
  - the distributed wrapper used for planning.
- Return type:
Tuple
The return value of this function should be passed as input to save_state_dict_async_finalize; together with the cached plans (see cached_ckpt_structure), it allows the reduce_scatter step of planning to be skipped in subsequent iterations. An end-to-end sketch follows below.
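The sketch below illustrates how the three stages fit together. Only save_state_dict_async_plan and save_state_dict_async_finalize are taken from this module; the FileSystemWriterAsync import path and constructor arguments are assumptions, and the stage-2 write is left as a placeholder because the writer API for launching the asynchronous write differs between versions.

```python
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import (
    FileSystemWriterAsync,  # import path assumed from the package layout
)
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)


def launch_async_write(writer: FileSystemWriterAsync) -> None:
    """Placeholder for stage 2: obtain the writer's write work, execute it
    asynchronously (e.g. in a background process), and wait for completion.
    The concrete writer API for this is version dependent, so it is omitted here."""
    raise NotImplementedError


def save_async(state_dict, ckpt_dir: str) -> None:
    """Sketch of the plan -> write -> finalize flow; must run on every rank."""
    writer = FileSystemWriterAsync(ckpt_dir)  # constructor arguments assumed

    # Stage 1: planning (collective). The coordinator rank gathers local plans
    # and produces the global metadata. Unpacking follows the documented return;
    # with caching enabled the result may additionally carry cached plans.
    storage_writer, global_metadata, dist_wrapper = save_state_dict_async_plan(
        state_dict, writer, process_group=None, coordinator_rank=0
    )

    # Stage 2: the actual write, performed asynchronously.
    launch_async_write(storage_writer)

    # Stage 3: finalization. Retrieves the write results from the storage writer
    # and writes the global metadata, completing the checkpoint.
    save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)
```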
- nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver.verify_global_md_reuse(loaded_all_plans, local_plan, rank, dist_wrapper)[source]
Verifies that global metadata reuse is possible by checking that the plans loaded from the checkpoint are consistent, i.e. that training is being resumed with the same settings.
- Parameters:
loaded_all_plans – local save plans loaded from the checkpoint being resumed
local_plan (SavePlan) – the local save plan of the current rank
rank (int) – the rank of the current process
dist_wrapper (_DistWrapper) – distributed wrapper used for planning
- Return type:
bool
Returns: True iff global metadata reuse is possible.
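A small sketch of how this check might gate metadata reuse is given below. The helper name and the way loaded_all_plans, the local plan, and the distributed wrapper are obtained are hypothetical; only the verify_global_md_reuse call itself comes from this module.

```python
import torch.distributed as dist
from torch.distributed.checkpoint.planner import SavePlan
from torch.distributed.checkpoint.utils import _DistWrapper

from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    verify_global_md_reuse,
)


def can_reuse_global_metadata(
    loaded_all_plans, local_plan: SavePlan, dist_wrapper: _DistWrapper
) -> bool:
    """Hypothetical helper: True only when the plans loaded from the checkpoint
    are consistent with the current run, i.e. training resumes with the same
    settings, so the cached global metadata can be reused."""
    if not loaded_all_plans:
        # Nothing was loaded from the checkpoint, so there is nothing to reuse.
        return False
    # The rank argument is assumed to be the caller's global rank.
    return verify_global_md_reuse(
        loaded_all_plans, local_plan, dist.get_rank(), dist_wrapper
    )
```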