FileSystemWriterAsync example

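"""Asynchronous checkpointing example using FileSystemWriterAsync.

This example trains a small FSDP-wrapped model on synthetic data, periodically
saves its state dict in the background through AsyncCallsQueue, then loads the
last checkpoint back and cleans up. The launch command below is illustrative
(any launcher that sets the env:// rendezvous variables works), with
placeholder values you would replace:

    torchrun --nproc_per_node=<num_gpus> <this_script>.py --ckpt_dir /tmp/test_checkpointing/
"""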
import argparse
import logging
import shutil

import multistorageclient as msc
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        default=True,
        dest='persistent_queue',
        help="Disables the persistent version of AsyncCallsQueue.",
    )
    parser.add_argument(
        '--enable_msc',
        action='store_true',
        help="Enables MSC for checkpoint saving and loading. See Usage Guide in Async Checkpointing documentation for detailed instructions.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(dist.get_node_local_rank())
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


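# A rough sketch of the async-save contract, inferred from the calls below: the
# writer's save function and its arguments perform the actual file writes in the
# background, the optional preload function stages data before that work starts,
# and the finalize function runs on every rank once the write has completed
# (here it finalizes the checkpoint metadata and synchronizes ranks).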
def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


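# Saving is split into two phases: save_state_dict_async_plan performs the
# synchronous planning step, while the heavy tensor writes are deferred into the
# AsyncRequest scheduled on the queue, so training can continue meanwhile.
# enable_cache=True is assumed to let later saves reuse planning results, paired
# with init_checkpoint_metadata_cache() in main().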
def save_checkpoint(checkpoint_dir, async_queue, model, thread_count, enable_msc):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count, use_msc=enable_msc)
    coordinator_rank = 0

    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


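# Loading is synchronous and goes through the regular torch.distributed.checkpoint
# load path, reading into the provided state_dict in place. With --enable_msc the
# MultiStorageClient reader is used instead of FileSystemReader; the assumption is
# that this lets the checkpoint live on MSC-supported storage backends (see the
# MSC documentation referenced in the --enable_msc help text).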
def load_checkpoint(checkpoint_dir, model, thread_count, enable_msc):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    if enable_msc:
        reader = msc.torch.MultiStorageFileSystemReader(checkpoint_dir, thread_count=thread_count)
    else:
        reader = FileSystemReader(checkpoint_dir)
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=reader,
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
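    # persistent=True (the default here) is assumed to keep one background worker
    # alive across save requests, while --no_persistent_queue falls back to a
    # temporary worker per request; see the AsyncCallsQueue documentation.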
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)

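    # Assumed to initialize the cache that enable_cache=True in
    # save_state_dict_async_plan() relies on, so repeated saves of the same
    # state-dict layout can skip redundant planning work.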
    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
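        # Non-blocking check: if a previously scheduled async save has finished,
        # run its finalize callbacks now; otherwise return immediately and train.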
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(
                checkpoint_dir, async_queue, fsdp_model, args.thread_count, args.enable_msc
            )
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

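    # Drain the queue before shutdown: blocking=True waits for any in-flight
    # async save to complete and runs its finalize callbacks.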
    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model, args.thread_count, args.enable_msc)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        if args.enable_msc:
            msc.delete(args.ckpt_dir, recursive=True)
        else:
            shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()