Source code for multistorageclient.contrib.numpy
1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from typing import Any, Union
17
18import numpy as _np
19
20from ..pathlib import MultiStoragePath
21from ..shortcuts import open as msc_open
22
23
[docs]
def memmap(*args: Any, **kwargs: Any) -> _np.memmap:
    """
    Adapt ``numpy.memmap``.

    When the filename is a ``str`` or a ``MultiStoragePath``, the file is opened
    through MSC, the opened file object's ``resolve_filesystem_path()`` result is
    substituted for the filename, and the call is forwarded to ``numpy.memmap``.
    Any other filename value is forwarded untouched.

    For ``str`` / ``MultiStoragePath`` filenames the mode defaults to ``"r"`` when
    the caller supplies none, rather than being left to ``numpy.memmap``.

    :param args: Positional arguments for ``numpy.memmap``; the filename is expected first.
    :param kwargs: Keyword arguments for ``numpy.memmap``; ``filename`` may be given here instead.
    :return: The array returned by ``numpy.memmap``.
    :raises TypeError: If no filename is supplied positionally or as ``filename=``.
    """
    # numpy.memmap(filename, dtype, mode, ...): ``mode`` is the third positional parameter.
    mode_position = 2

    filename_is_keyword = not args and "filename" in kwargs
    if not args and not filename_is_keyword:
        raise TypeError("missing filename argument")
    file = kwargs["filename"] if filename_is_keyword else args[0]

    # Only MSC-addressable filenames are rewritten; everything else goes straight to numpy.
    if not isinstance(file, (str, MultiStoragePath)):
        return _np.memmap(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]

    if "mode" in kwargs:
        mode = kwargs["mode"]
    elif len(args) > mode_position:
        # Mode was passed positionally: honour it, and do not inject a keyword
        # duplicate, which would make numpy.memmap raise "multiple values for 'mode'".
        mode = args[mode_position]
    else:
        mode = kwargs["mode"] = "r"

    # NOTE(review): the numpy mode string is handed to the MSC open call as-is;
    # confirm MSC accepts every mode numpy.memmap does (e.g. "c", "w+").
    if isinstance(file, str):
        opened = msc_open(file, mode=str(mode))
    else:
        opened = file.open(mode=str(mode))
    with opened as fp:
        local_path = fp.resolve_filesystem_path()

    if filename_is_keyword:
        kwargs["filename"] = local_path
    else:
        args = (local_path,) + args[1:]

    return _np.memmap(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
45
46
[docs]
def load(*args: Any, **kwargs: Any) -> Union[_np.ndarray, dict[str, _np.ndarray], _np.lib.npyio.NpzFile]:
    """
    Adapt ``numpy.load``.

    A ``str`` or ``MultiStoragePath`` file argument (first positional, or ``file=``)
    is replaced by the value of ``MultiStoragePath.as_posix()`` before the call is
    forwarded to ``numpy.load``. Any other file argument is forwarded unchanged.

    :raises ValueError: If ``as_posix()`` yields an empty/falsy path.
    """
    file_is_positional = bool(args)
    file = args[0] if file_is_positional else kwargs.get("file")

    # Anything that is not an MSC-addressable path goes straight to numpy.
    if not isinstance(file, (str, MultiStoragePath)):
        return _np.load(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]

    # For .npy files with a memmap mode other than none, _np.load() calls
    # format.open_memmap() underneath, which requires a file path string.
    # Refs:
    #   https://github.com/numpy/numpy/blob/main/numpy/lib/_npyio_impl.py#L477
    #   https://numpy.org/doc/stable/reference/generated/numpy.lib.format.open_memmap.html
    #
    # To keep the code simple, a path is always handed to _np.load, which turns it
    # into a file-like object itself.
    storage_path = MultiStoragePath(file) if isinstance(file, str) else file

    # Per the original author's note, this call blocks until the download has
    # completed, so the local path exists for the open() performed inside _np.load().
    local_path = storage_path.as_posix()
    if not local_path:
        raise ValueError(f"local_path={local_path} for the downloaded file[{file}] is not valid")

    if file_is_positional:
        args = (local_path, *args[1:])
    else:
        kwargs["file"] = local_path

    return _np.load(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
80
81
[docs]
def save(*args: Any, **kwargs: Any) -> None:
    """
    Adapt ``numpy.save``.

    A ``str`` or ``MultiStoragePath`` file argument (first positional, or ``file=``)
    is opened for binary writing through MSC, and the opened file object is passed
    to ``numpy.save`` in its place. Any other file argument is forwarded unchanged.

    :param attributes: Optional dictionary of custom attributes/metadata to attach to the file.
    """
    # ``attributes`` is consumed here and never forwarded to numpy.save.
    attributes_dict = kwargs.pop("attributes", {})

    file_is_positional = bool(args)
    file = args[0] if file_is_positional else kwargs.get("file")

    if isinstance(file, str):
        writer = msc_open(file, mode="wb", attributes=attributes_dict)
    elif isinstance(file, MultiStoragePath):
        writer = file.open(mode="wb", attributes=attributes_dict)
    else:
        # Not an MSC-addressable target: let numpy handle it directly.
        _np.save(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]
        return

    with writer as fp:
        # Put the opened file object where the caller supplied the original target.
        if file_is_positional:
            args = (fp, *args[1:])
        else:
            kwargs["file"] = fp

        _np.save(*args, **kwargs)  # pyright: ignore [reportArgumentType, reportCallIssue]