1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2# SPDX-License-Identifier: Apache-2.0 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15 16fromtypingimportAny,Union 17 18importnumpyas_np 19 20from..pathlibimportMultiStoragePath 21from..shortcutsimportopenasmsc_open 22 23
def load(*args: Any, **kwargs: Any) -> Union[_np.ndarray, dict[str, _np.ndarray], _np.lib.npyio.NpzFile]:
    """
    Adapt ``numpy.load``.

    All positional and keyword arguments are forwarded to :func:`numpy.load`. The
    ``file`` argument (first positional argument, or the ``file=`` keyword) is
    rewritten first when it is a ``str`` or a :class:`MultiStoragePath`. In that case
    it is replaced by the string returned from ``MultiStoragePath.as_posix()``. Any
    other ``file`` value (e.g. an already-open file object) is passed through untouched.

    :return: Whatever :func:`numpy.load` returns for the forwarded arguments.
    :raises ValueError: If the resolved path string is empty.
    """
    file = args[0] if args else kwargs.get("file")

    # Only path-like inputs need rewriting; everything else goes straight to numpy.
    if not isinstance(file, (str, MultiStoragePath)):
        return _np.load(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]

    # For .npy files with mmap_mode != None, numpy.load() delegates to
    # numpy.lib.format.open_memmap(), which requires a filename rather than a
    # file object. To keep a single code path, a path string is always passed
    # and numpy opens the file itself.
    # Refs:
    #   https://github.com/numpy/numpy/blob/main/numpy/lib/_npyio_impl.py#L477
    #   https://numpy.org/doc/stable/reference/generated/numpy.lib.format.open_memmap.html
    storage_path = MultiStoragePath(file) if isinstance(file, str) else file

    # NOTE(review): this relies on as_posix() blocking until the object is
    # available as a local file, so that numpy's own open() succeeds.
    # Confirm against MultiStoragePath.as_posix().
    local_path = storage_path.as_posix()
    if not local_path:
        raise ValueError(f"local_path={local_path} for the downloaded file[{file}] is not valid")

    # Substitute the resolved path wherever the caller supplied ``file``.
    if args:
        args = (local_path, *args[1:])
    else:
        kwargs["file"] = local_path

    return _np.load(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
80 81
def save(*args: Any, **kwargs: Any) -> None:
    """
    Adapt ``numpy.save``.

    All positional and keyword arguments are forwarded to :func:`numpy.save`. The
    ``file`` argument (first positional argument, or the ``file=`` keyword) is
    handled as follows:

    - ``str``: opened with :func:`msc_open` in ``"wb"`` mode.
    - :class:`MultiStoragePath`: opened with ``MultiStoragePath.open`` in ``"wb"`` mode.
    - anything else: forwarded to :func:`numpy.save` unchanged.

    For the first two cases, the opened file object replaces ``file`` in the call to
    :func:`numpy.save`.

    NOTE: for ``str`` / ``MultiStoragePath`` targets numpy receives a file object, so
    it does not append a ``.npy`` extension to the name. This differs from calling
    ``numpy.save`` directly with a filename.
    """
    file = args[0] if args else kwargs.get("file")

    if isinstance(file, str):
        # Use a context manager to make sure the file is uploaded once close() is called.
        writer = msc_open(file, mode="wb")
    elif isinstance(file, MultiStoragePath):
        writer = file.open(mode="wb")
    else:
        # Not a storage path (e.g. an open file object): defer entirely to numpy.
        _np.save(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
        return

    with writer as fp:
        # Substitute the opened handle wherever the caller supplied ``file``.
        if args:
            args = (fp, *args[1:])
        else:
            kwargs["file"] = fp

        _np.save(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]