FileSystemWriterAsync example

import argparse
import logging
import shutil

import multistorageclient as msc
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)

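# Model, dataset, and training hyperparameters for this toy workload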
FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        default=True,
        dest='persistent_queue',
        help=(
            "Disables the persistent version of AsyncCallsQueue "
            "used for asynchronous checkpoint saves."
        ),
    )
    parser.add_argument(
        '--enable_msc',
        action='store_true',
        help="Enables MSC for checkpoint saving and loading. See the Usage Guide in the Async Checkpointing documentation for detailed instructions.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
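        # Random features/targets are generated directly on the GPU; fine here because the DataLoader runs in-process (num_workers=0)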
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(dist.get_node_local_rank())
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
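    # The async writer exposes the function to run in the background worker, a preload function
    # (used to stage tensors before the background write starts), and the arguments for the save.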
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(
        save_fn, save_args, [finalize_fn], async_fn_kwargs={}, preload_fn=preload_fn
    )


def save_checkpoint(checkpoint_dir, async_queue, model, thread_count, enable_msc):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count, use_msc=enable_msc)
    coordinator_rank = 0

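    # Plan the distributed save; enable_cache=True lets subsequent saves reuse cached planning
    # metadata (see init_checkpoint_metadata_cache() in main())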
    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


def load_checkpoint(checkpoint_dir, model, thread_count, enable_msc):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
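    # Read through multistorageclient (MSC) when enabled; otherwise use the standard filesystem reader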
    if enable_msc:
        reader = msc.torch.MultiStorageFileSystemReader(checkpoint_dir, thread_count=thread_count)
    else:
        reader = FileSystemReader(checkpoint_dir)
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=reader,
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, distributed sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
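    # persistent=True reuses a single background checkpoint worker across saves (disable with --no_persistent_queue)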
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)

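    # Initialize the metadata cache consumed when planning saves with enable_cache=True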
    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
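        # Finalize a previously scheduled async checkpoint if it has finished; non-blocking, so training continues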
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(
                checkpoint_dir, async_queue, fsdp_model, args.thread_count, args.enable_msc
            )
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
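    # Block until the in-flight async checkpoint (if any) is fully persisted before closing the queue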
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model, args.thread_count, args.enable_msc)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        if args.enable_msc:
            msc.delete(args.ckpt_dir, recursive=True)
        else:
            shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()