# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TorchAsyncCheckpoint defines a wrapper for the async version of `torch.save` with
an additional method to synchronize async saving requests
"""
import logging
import torch
from ..utils import wrap_for_async, preload_tensors
from .core import AsyncCallsQueue, AsyncRequest
logger = logging.getLogger(__name__)
[docs]
class TorchAsyncCheckpoint(object):
    """Wrapper exposing an asynchronous variant of `torch.save`.

    Save requests are scheduled on an internal `AsyncCallsQueue` via
    `async_save` and later completed with `finalize_async_save`.
    """

    # Async-wrapped `torch.save`. Stored on the class (not the instance) and
    # re-assigned every time an instance is constructed.
    async_fn = None

    def __init__(self):
        # The plain, synchronous `torch.save` stays reachable as `self.save`.
        self.save = torch.save
        self._async_calls_queue = AsyncCallsQueue()
        TorchAsyncCheckpoint.async_fn = wrap_for_async(torch.save)

    def async_save(self, state_dict, *args, **kwargs):
        """Schedule an asynchronous `torch.save` of ``state_dict``.

        Keeps the original interface of `torch.save`.
        Schedules an `AsyncRequest` with preloading tensors to CPU with pinned memcpy.

        Args:
            state_dict: object to save; it is passed through `preload_tensors`
                before being handed to the async request.
            *args: extra positional arguments forwarded after the preloaded
                state dict to the async-wrapped `torch.save`.
            **kwargs: keyword arguments forwarded to the async-wrapped
                `torch.save`.
        """
        preloaded_sd = preload_tensors(state_dict)
        # Wait for outstanding device work before scheduling the request
        # (presumably the preload copies are non-blocking -- confirm in
        # `preload_tensors`). Hosts without CUDA have nothing to wait for, and
        # `torch.cuda.synchronize()` would raise there, so skip it.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # The empty list fills AsyncRequest's third positional field
        # (presumably the finalization callbacks -- confirm in `.core`).
        async_request = AsyncRequest(
            TorchAsyncCheckpoint.async_fn, (preloaded_sd, *args), [], kwargs
        )
        self._async_calls_queue.schedule_async_request(async_request)

    def finalize_async_save(self, blocking: bool = False, no_dist: bool = True) -> None:
        """Finalizes active async save calls.

        Args:
            blocking (bool, optional): if True, will wait until all active requests
                are done. Otherwise, finalizes only the async request that already
                finished. Defaults to False.
            no_dist (bool, optional): forwarded unchanged to
                `AsyncCallsQueue.maybe_finalize_async_calls`. Presumably disables
                cross-rank coordination during finalization -- confirm against
                the queue implementation. Defaults to True.
        """
        if blocking and self._async_calls_queue.get_num_unfinalized_calls() > 0:
            # `get_rank()` raises when no default process group exists, which
            # is a legitimate setup here given the `no_dist=True` default.
            # Only query the rank when a group is initialized; otherwise this
            # is the only process and it should emit the message itself.
            dist = torch.distributed
            dist_ready = dist.is_available() and dist.is_initialized()
            if not dist_ready or dist.get_rank() == 0:
                logger.info('Unfinalized async checkpoint saves. Finalizing them synchronously now.')
        self._async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=no_dist)