Source code for multistorageclient.contrib.numpy

 1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 2# SPDX-License-Identifier: Apache-2.0
 3#
 4# Licensed under the Apache License, Version 2.0 (the "License");
 5# you may not use this file except in compliance with the License.
 6# You may obtain a copy of the License at
 7#
 8# http://www.apache.org/licenses/LICENSE-2.0
 9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from typing import Any, Dict, Union
17
18import numpy as _np
19
20from ..shortcuts import open as msc_open
21from ..types import MSC_PROTOCOL
22
23
def memmap(*args: Any, **kwargs: Any) -> _np.memmap:
    """
    Adapt ``numpy.memmap``.

    Accepts the same arguments as ``numpy.memmap``. When the filename is a string that starts with
    ``MSC_PROTOCOL``, the file is opened through ``msc_open`` and the local path reported by the
    opened file is handed to ``numpy.memmap`` in its place. Any other filename is passed through
    unchanged.

    For MSC paths, ``mode`` defaults to ``"r"`` when the caller does not supply one.

    :param args: Positional arguments for ``numpy.memmap``; the first one is the filename.
    :param kwargs: Keyword arguments for ``numpy.memmap``; the filename may be given as ``filename``.
    :return: The object returned by ``numpy.memmap``.
    :raises TypeError: If no filename is given, neither positionally nor as ``filename``.
    """

    # numpy.memmap names its first parameter ``filename``; accept it positionally or by keyword.
    if args:
        file = args[0]
    elif "filename" in kwargs:
        file = kwargs["filename"]
    else:
        raise TypeError("missing filename argument")

    if isinstance(file, str) and file.startswith(MSC_PROTOCOL):
        # ``mode`` is the third positional parameter of numpy.memmap. Only inject the default into
        # kwargs when it was not passed positionally, otherwise numpy.memmap would receive it twice.
        if len(args) >= 3:
            mode = args[2]
        else:
            mode = kwargs.setdefault("mode", "r")

        with msc_open(file, mode=str(mode)) as fp:
            local_path = fp.get_local_path()

        # Substitute the local path wherever the caller originally put the filename.
        if args:
            args = (local_path,) + args[1:]
        else:
            kwargs["filename"] = local_path

    return _np.memmap(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
40 41
def load(*args: Any, **kwargs: Any) -> Union[_np.ndarray, Dict[str, _np.ndarray], _np.lib.npyio.NpzFile]:
    """
    Adapt ``numpy.load``.

    The file is taken from the first positional argument, or from the ``file`` keyword when no
    positional arguments are given. If it is a string starting with ``MSC_PROTOCOL``, it is opened
    with ``msc_open`` and replaced by the local path of the opened file before delegating to
    ``numpy.load``. Anything else is forwarded unchanged.

    :param args: Positional arguments for ``numpy.load``.
    :param kwargs: Keyword arguments for ``numpy.load``.
    :return: Whatever ``numpy.load`` returns.
    :raises ValueError: If the opened MSC file reports an empty local path.
    """

    file = kwargs.get("file") if not args else args[0]

    # Not an MSC URL: nothing to adapt, delegate straight to numpy.
    if not isinstance(file, str) or not file.startswith(MSC_PROTOCOL):
        return _np.load(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]

    with msc_open(file) as remote:
        # For .npy files loaded with a memmap mode, _np.load() calls format.open_memmap() underneath,
        # which requires a file path string.
        # Refs:
        #   https://github.com/numpy/numpy/blob/main/numpy/lib/_npyio_impl.py#L477
        #   https://numpy.org/doc/stable/reference/generated/numpy.lib.format.open_memmap.html
        #
        # To keep the code simple, the path is always passed to _np.load, which opens it itself.

        # Per the original note: wait for the download to finish so the local path exists by the time
        # _np.load() opens it.
        local_path = remote.get_local_path()
        if not local_path:
            raise ValueError(f"local_path={local_path} for the downloaded file[{file}] is not valid")

    # Put the local path back in the slot the caller used for the file.
    if args:
        args = (local_path, *args[1:])
    else:
        kwargs["file"] = local_path

    return _np.load(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
69 70
def save(*args: Any, **kwargs: Any) -> None:
    """
    Adapt ``numpy.save``.

    The target is taken from the first positional argument, or from the ``file`` keyword when no
    positional arguments are given. If it is a string starting with ``MSC_PROTOCOL``, it is opened
    with ``msc_open`` in ``"wb"`` mode and the opened file object is handed to ``numpy.save`` in
    its place. Anything else is forwarded unchanged.

    :param args: Positional arguments for ``numpy.save``.
    :param kwargs: Keyword arguments for ``numpy.save``.
    """

    file = kwargs.get("file") if not args else args[0]

    # Not an MSC URL: delegate straight to numpy.
    if not isinstance(file, str) or not file.startswith(MSC_PROTOCOL):
        _np.save(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
        return

    # The write happens inside the context manager so that the file is uploaded once close() is called.
    with msc_open(file, mode="wb") as sink:
        if args:
            args = (sink, *args[1:])
        else:
            kwargs["file"] = sink

        _np.save(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]