Source code for multistorageclient.contrib.torch.core
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import os
17from typing import IO, Any, Union
18
19import torch as _torch
20
21from ...pathlib import MultiStoragePath
22from ...shortcuts import open as msc_open
23
24
[docs]
def load(f: Union[str, os.PathLike[str], IO[bytes]], *args: Any, **kwargs: Any) -> Any:
    """
    Adapt ``torch.load`` so it can read checkpoints through the multi-storage client.

    :param f: Source to read from. A ``str`` is opened with the MSC ``open`` shortcut,
        a :class:`MultiStoragePath` is opened via its own ``open`` method, and any other
        value (e.g. an already-open binary file object) is handed to ``torch.load`` unchanged.
    :param args: Extra positional arguments forwarded verbatim to ``torch.load``.
    :param kwargs: Extra keyword arguments forwarded verbatim to ``torch.load``.
    :return: The object produced by ``torch.load``.
    """
    # Inputs MSC does not need to resolve go straight to PyTorch.
    if not isinstance(f, (str, MultiStoragePath)):
        return _torch.load(f, *args, **kwargs)

    if isinstance(f, str):
        # String paths are opened with prefetch_file=True.
        # NOTE(review): presumably this fetches the whole object up front; confirm in msc_open.
        handle = msc_open(f, "rb", prefetch_file=True)
    else:
        handle = f.open("rb")

    # The handle is closed when torch.load returns or raises.
    with handle as stream:
        return _torch.load(stream, *args, **kwargs)
37
38
[docs]
def save(obj: object, f: Union[str, os.PathLike[str], IO[bytes]], *args: Any, **kwargs: Any) -> Any:
    """
    Adapt ``torch.save`` so it can write checkpoints through the multi-storage client.

    :param obj: The object to serialize.
    :param f: Destination to write to. A ``str`` is opened with the MSC ``open`` shortcut,
        a :class:`MultiStoragePath` is opened via its own ``open`` method, and any other
        value is handed to ``torch.save`` unchanged.
    :param args: Extra positional arguments forwarded verbatim to ``torch.save``.
    :param kwargs: Extra keyword arguments forwarded to ``torch.save``. An optional
        ``attributes`` entry (default ``{}``) is removed first and passed as the
        ``attributes`` argument of the MSC open call. It is discarded when ``f`` is
        neither a ``str`` nor a :class:`MultiStoragePath`.
    :return: Whatever ``torch.save`` returns.
    """
    # Remove the MSC-only keyword so torch.save never receives it.
    attrs = kwargs.pop("attributes", {})

    if isinstance(f, str):
        sink = msc_open(f, "wb", attributes=attrs)
    elif isinstance(f, MultiStoragePath):
        sink = f.open("wb", attributes=attrs)
    else:
        # Non-MSC destinations: attrs is intentionally dropped here.
        return _torch.save(obj, f, *args, **kwargs)

    # The sink is closed when torch.save returns or raises.
    with sink as stream:
        return _torch.save(obj, stream, *args, **kwargs)