Source code for multistorageclient.contrib.numpy

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16from typing import Any, Union
 17
 18import numpy as _np
 19
 20from ..pathlib import MultiStoragePath
 21from ..shortcuts import open as msc_open
 22
 23
[docs] 24def memmap(*args: Any, **kwargs: Any) -> _np.memmap: 25 """ 26 Adapt ``numpy.memmap``. 27 """ 28 29 if not args: 30 raise TypeError("missing filename argument") 31 file = args[0] 32 33 if isinstance(file, str): 34 if "mode" not in kwargs: 35 kwargs["mode"] = "r" 36 with msc_open(file, mode=str(kwargs.get("mode"))) as fp: 37 args = (fp.resolve_filesystem_path(),) + args[1:] 38 elif isinstance(file, MultiStoragePath): 39 if "mode" not in kwargs: 40 kwargs["mode"] = "r" 41 with file.open(mode=str(kwargs.get("mode"))) as fp: 42 args = (fp.resolve_filesystem_path(),) + args[1:] 43 44 return _np.memmap(*args, **kwargs) # pyright: ignore [reportArgumentType, reportCallIssue]
45 46
[docs] 47def load(*args: Any, **kwargs: Any) -> Union[_np.ndarray, dict[str, _np.ndarray], _np.lib.npyio.NpzFile]: 48 """ 49 Adapt ``numpy.load``. 50 """ 51 52 file = args[0] if args else kwargs.get("file") 53 54 def resolve_filesystem_path(file: Union[str, MultiStoragePath]) -> Union[str, None]: 55 """Helper function to get the local path from a filepath or MultiStoragePath.""" 56 if isinstance(file, str): 57 file = MultiStoragePath(file) 58 return file.as_posix() 59 60 if isinstance(file, str) or isinstance(file, MultiStoragePath): 61 # For .npy with memmap mode != none, _np.load() will call format.open_memmap() underneath, 62 # Which require a file path string 63 # Refs: 64 # https://github.com/numpy/numpy/blob/main/numpy/lib/_npyio_impl.py#L477 65 # https://numpy.org/doc/stable/reference/generated/numpy.lib.format.open_memmap.html 66 # 67 # For the simplicity of the code, we always pass the file path to _np.load and let it convert the path 68 # to file-like object. 69 70 # block until download is completed to ensure local path is available for the open() call within _np.load() 71 local_path = resolve_filesystem_path(file) 72 if not local_path: 73 raise ValueError(f"local_path={local_path} for the downloaded file[{file}] is not valid") 74 if args: 75 args = (local_path,) + args[1:] 76 else: 77 kwargs["file"] = local_path 78 79 return _np.load(*args, **kwargs) # pyright: ignore [reportArgumentType, reportCallIssue]
80 81
[docs] 82def save(*args: Any, **kwargs: Any) -> None: 83 """ 84 Adapt ``numpy.save``. 85 """ 86 87 file = args[0] if args else kwargs.get("file") 88 if isinstance(file, str): 89 # use context manager to make sure to upload the file once close() is called 90 with msc_open(file, mode="wb") as fp: 91 if args: 92 args = (fp,) + args[1:] 93 else: 94 kwargs["file"] = fp 95 96 _np.save(*args, **kwargs) # pyright: ignore [reportArgumentType, reportCallIssue] 97 elif isinstance(file, MultiStoragePath): 98 with file.open(mode="wb") as fp: 99 if args: 100 args = (fp,) + args[1:] 101 else: 102 kwargs["file"] = fp 103 104 _np.save(*args, **kwargs) # pyright: ignore [reportArgumentType, reportCallIssue] 105 else: 106 _np.save(*args, **kwargs) # pyright: ignore [reportArgumentType, reportCallIssue]