Coverage for cuda/core/_memoryview.pyx: 62.99% (708 statements)

# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from ._dlpack cimport *
from ._dlpack import classify_dl_device
from libc.stdint cimport intptr_t
from cuda.core._layout cimport _StridedLayout, get_strides_ptr
from cuda.core._stream import Stream

import ctypes
import functools
import sys
import warnings

import numpy

from cuda.bindings cimport cydriver
from cuda.core._resource_handles cimport (
    EventHandle,
    create_event_handle_noctx,
    as_cu,
)

from cuda.core._utils.cuda_utils import handle_return, driver
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN

from cuda.core._memory import Buffer

# ---------------------------------------------------------------------------
# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used)
# ---------------------------------------------------------------------------

cdef object _tensor_bridge = None
# Cache: type(obj) -> True/False for the torch tensor check.
# Once a type is seen, we never re-check.
cdef dict _torch_type_cache = {}
# Tri-state: None = not checked, True/False = result of version check
cdef object _torch_version_ok = None

cdef inline bint _torch_version_check():
    """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized.

    Lower bound: AOTI functions we use were introduced in PyTorch 2.3.
    Upper bound: the ``pyobj_to_aten_handle`` trick relies on the
    THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata)
    and the identity ``AtenTensorHandle == at::Tensor*``. Both are
    undocumented internals that could change in a future PyTorch version.
    We cap at the latest version we have tested against; unknown versions
    fall back to the standard DLPack/CAI paths. Bump the upper bound
    after verifying a new PyTorch release.
    """
    global _torch_version_ok
    if _torch_version_ok is not None:
        return <bint>_torch_version_ok
    torch = sys.modules.get("torch")
    if torch is None:
        _torch_version_ok = False
        return False
    try:
        major, minor = int(torch.__version__.split(".")[0]), \
            int(torch.__version__.split(".")[1])
        _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11)
    except (ValueError, IndexError):
        _torch_version_ok = False
    return <bint>_torch_version_ok

cdef inline bint _is_torch_tensor(object obj):
    cdef type tp = type(obj)
    cdef object cached = _torch_type_cache.get(tp)
    if cached is not None:
        return <bint>cached
    cdef str mod = tp.__module__ or ""
    cdef bint result = mod.startswith("torch") and hasattr(obj, "data_ptr") \
        and _torch_version_check()
    _torch_type_cache[tp] = result
    return result

cdef object _get_tensor_bridge():
    """Bootstrap AOTI symbols, then import _tensor_bridge on first use."""
    global _tensor_bridge
    if _tensor_bridge is not None:
        return _tensor_bridge
    torch_C = sys.modules.get("torch._C")
    if torch_C is None:
        raise RuntimeError(
            "torch._C is not loaded; cannot initialise the tensor bridge. "
            "Make sure PyTorch is imported before passing a torch.Tensor.")
    ctypes.CDLL(torch_C.__file__, mode=ctypes.RTLD_GLOBAL)
    from cuda.core import _tensor_bridge as tb
    _tensor_bridge = tb
    return _tensor_bridge


try:
    from ml_dtypes import bfloat16
except ImportError:
    bfloat16 = None

# TODO(leofang): support NumPy structured dtypes

cdef extern from "Python.h":
    ctypedef struct PyTypeObject:
        void* tp_dict
    void PyType_Modified(PyTypeObject*)


cdef DLPackExchangeAPI _SMV_DLPACK_EXCHANGE_API
cdef bint _SMV_DLPACK_EXCHANGE_API_INITED = False
_SMV_DLPACK_EXCHANGE_API_CAPSULE = cpython.PyCapsule_New(
    <void*>&_SMV_DLPACK_EXCHANGE_API,
    b"dlpack_exchange_api",
    NULL,
)

cdef class StridedMemoryView:
    """A class holding metadata of a strided dense array/tensor.

    A :obj:`StridedMemoryView` instance can be created in three ways:

    1. Using the :obj:`args_viewable_as_strided_memory` decorator (recommended)
    2. Explicit construction relying on DLPack or CUDA Array Interface, see below.
    3. From a :obj:`~_memory.Buffer` plus shape and strides tuples (see the
       :meth:`from_buffer` classmethod)

    ``StridedMemoryView(obj, stream_ptr)`` can be used to create a view from
    objects supporting either DLPack (up to v1.0) or CUDA Array Interface
    (CAI) v3. When wrapping an arbitrary object it will try the DLPack protocol
    first, then the CAI protocol. A :obj:`BufferError` is raised if neither is
    supported.

    Both protocols take a consumer stream: for DLPack, it is passed to
    ``obj.__dlpack__()`` as-is (except for :obj:`None`, see below); for CAI, a
    stream order is established between the consumer stream and the
    producer stream (from ``obj.__cuda_array_interface__["stream"]``), as if
    ``cudaStreamWaitEvent`` were called by this method.

    To opt out of the stream ordering operation in either DLPack or CAI,
    please pass ``stream_ptr=-1``. Note that this deviates (on purpose)
    from the semantics of ``obj.__dlpack__(stream=None, ...)`` since ``cuda.core``
    does not encourage using the (legacy) default/null stream, but is
    consistent with the CAI's semantics. For DLPack, ``stream=-1`` will be
    internally passed to ``obj.__dlpack__()`` instead.

    Parameters
    ----------
    obj : Any
        Any object that supports either DLPack (up to v1.0) or CUDA Array
        Interface (v3).
    stream_ptr : int
        The pointer address (as Python `int`) to the **consumer** stream.
        Stream ordering will be properly established unless ``-1`` is passed.

    Attributes
    ----------
    ptr : int
        Pointer to the tensor buffer (as a Python `int`).
    device_id : int
        The device ID for where the tensor is located. It is -1 for CPU tensors
        (meaning the data is only accessible from the host).
    is_device_accessible : bool
        Whether the tensor data can be accessed on the GPU.
    readonly : bool
        Whether the tensor data is read-only and must not be modified in place.
    exporting_obj : Any
        A reference to the original tensor object that is being viewed.
        If the view is created with :meth:`from_buffer`,
        it will be the Buffer instance passed to the method.

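    Examples
    --------
    A minimal sketch of the recommended classmethod path (this assumes CuPy is
    installed; any DLPack- or CAI-capable producer works the same way):

    .. code-block:: python

        import cupy as cp

        arr = cp.arange(100, dtype=cp.float32)
        stream = cp.cuda.Stream()
        # DLPack is tried first, then the CUDA Array Interface.
        view = StridedMemoryView.from_any_interface(arr, stream.ptr)
        assert view.is_device_accessible
        assert view.shape == (100,)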
    """
    def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
        cdef str clsname = self.__class__.__name__
        if obj is not None:
            # populate self's attributes
            if check_has_dlpack(obj):
                warnings.warn(
                    f"Constructing a {clsname} directly from a DLPack-supporting object is deprecated; "
                    "Use `StridedMemoryView.from_dlpack` or `StridedMemoryView.from_any_interface` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                view_as_dlpack(obj, stream_ptr, self)
            else:
                warnings.warn(
                    f"Constructing a {clsname} directly from a CUDA-array-interface-supporting object is deprecated; "
                    "Use `StridedMemoryView.from_cuda_array_interface` or `StridedMemoryView.from_any_interface` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                view_as_cai(obj, stream_ptr, self)
        else:
            warnings.warn(
                f"Constructing an empty {clsname} is deprecated; "
                "use one of the classmethods `from_dlpack`, `from_cuda_array_interface` or `from_any_interface` "
                "to construct a StridedMemoryView from an object",
                DeprecationWarning,
                stacklevel=2,
            )

    @classmethod
    def from_dlpack(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
        """Create a view from an object supporting the `DLPack <https://dmlc.github.io/dlpack/latest/>`_ protocol.

        Parameters
        ----------
        obj : object
            An object implementing the `DLPack <https://dmlc.github.io/dlpack/latest/>`_ protocol
            (via ``__dlpack__``).
        stream_ptr : int, optional
            Stream pointer for synchronization. If ``None``, no synchronization is performed.

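        Examples
        --------
        An illustrative sketch using PyTorch as the producer (assumes PyTorch
        with CUDA support is installed; any DLPack-capable library works):

        .. code-block:: python

            import torch

            t = torch.arange(10, device="cuda", dtype=torch.float32)
            s = torch.cuda.current_stream()
            view = StridedMemoryView.from_dlpack(t, s.cuda_stream)
            assert view.is_device_accessible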
        """
        cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
        if _is_torch_tensor(obj):
            _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
            return buf
        view_as_dlpack(obj, stream_ptr, buf)
        return buf

    @classmethod
    def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
        """Create a view from an object supporting the `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_ protocol.

        Parameters
        ----------
        obj : object
            An object implementing the `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_ protocol.
        stream_ptr : int, optional
            Stream pointer for synchronization. If ``None``, no synchronization is performed.

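        Examples
        --------
        A minimal sketch (assumes CuPy is installed; CuPy arrays expose the
        CUDA Array Interface):

        .. code-block:: python

            import cupy as cp

            arr = cp.ones((2, 3), dtype=cp.float64)
            stream = cp.cuda.Stream()
            view = StridedMemoryView.from_cuda_array_interface(arr, stream.ptr)
            assert view.shape == (2, 3)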
        """
        cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
        if _is_torch_tensor(obj):
            _get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
            return buf
        view_as_cai(obj, stream_ptr, buf)
        return buf

    @classmethod
    def from_array_interface(cls, obj: object) -> StridedMemoryView:
        """Create a view from an object supporting the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol.

        Parameters
        ----------
        obj : object
            An object implementing the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol (e.g., a numpy array).

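        Examples
        --------
        An illustrative sketch viewing a host NumPy array (a CUDA context is
        assumed to be current, since the current device ID is still queried):

        .. code-block:: python

            import numpy as np

            arr = np.zeros((4, 4), dtype=np.float64)
            view = StridedMemoryView.from_array_interface(arr)
            assert not view.is_device_accessible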
        """
        cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
        if _is_torch_tensor(obj):
            _get_tensor_bridge().view_as_torch_tensor(obj, None, buf)
            return buf
        view_as_array_interface(obj, buf)
        return buf

    @classmethod
    def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
        """Create a view by automatically selecting the best available protocol.

        Tries `DLPack <https://dmlc.github.io/dlpack/latest/>`_ first, then falls back to
        `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_.
        ``torch.Tensor`` objects are transparently handled via a fast AOTI path
        regardless of which protocol is selected.

        Parameters
        ----------
        obj : object
            An object implementing `DLPack <https://dmlc.github.io/dlpack/latest/>`_ or
            `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_.
        stream_ptr : int, optional
            Stream pointer for synchronization. If ``None``, no synchronization is performed.

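        Examples
        --------
        A minimal sketch that accepts both device and host arrays (assumes
        CuPy and NumPy are installed):

        .. code-block:: python

            import cupy as cp
            import numpy as np

            d_arr = cp.ones(8, dtype=cp.float32)
            h_arr = np.ones(8, dtype=np.float32)
            stream = cp.cuda.Stream()
            d_view = StridedMemoryView.from_any_interface(d_arr, stream.ptr)
            # -1 opts out of stream ordering (see the class docstring).
            h_view = StridedMemoryView.from_any_interface(h_arr, -1)
            assert d_view.is_device_accessible
            assert not h_view.is_device_accessible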
        """
        if check_has_dlpack(obj):
            return cls.from_dlpack(obj, stream_ptr)
        return cls.from_cuda_array_interface(obj, stream_ptr)

    @classmethod
    def from_buffer(
        cls,
        buffer : Buffer,
        shape : tuple[int, ...],
        strides : tuple[int, ...] | None = None,
        *,
        itemsize : int | None = None,
        dtype : numpy.dtype | None = None,
        is_readonly : bool = False
    ) -> StridedMemoryView:
        """
        Creates a :obj:`StridedMemoryView` instance from a :obj:`~_memory.Buffer` and shape and strides tuples.
        The Buffer can be either an allocation coming from a :obj:`MemoryResource` or an external allocation
        wrapped in a :obj:`~_memory.Buffer` object with ``Buffer.from_handle(ptr, size, owner=...)``.

        .. caution::
            When creating a :obj:`StridedMemoryView` from a :obj:`~_memory.Buffer`,
            no synchronization is performed. It is the user's responsibility to ensure
            the data in ``buffer`` is properly synchronized when consuming the view.

        Parameters
        ----------
        buffer : :obj:`~_memory.Buffer`
            The buffer to create the view from.
        shape : :obj:`tuple`
            The shape of the view, i.e. the number of elements along each dimension.
        strides : :obj:`tuple`, optional
            The strides of the view, in element counts (not bytes).
        itemsize : int, optional
            The size in bytes of each element. Either ``itemsize`` or ``dtype``
            must be specified.
        dtype : :obj:`numpy.dtype`, optional
            Optional dtype.
            If specified, the dtype's itemsize must match the layout's itemsize.
        is_readonly : bool, optional
            Whether to mark the view as readonly.

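        Examples
        --------
        A minimal sketch (``buf`` is assumed to be an existing
        :obj:`~_memory.Buffer` holding at least 24 ``float32`` elements, e.g.
        one returned by a memory resource or created with
        ``Buffer.from_handle(ptr, size, owner=...)``):

        .. code-block:: python

            import numpy as np

            # View 24 float32 elements in `buf` as a 4 x 6 array.
            view = StridedMemoryView.from_buffer(
                buf, shape=(4, 6), dtype=np.dtype(np.float32))
            assert view.shape == (4, 6)
            assert view.dtype == np.float32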
        """
        cdef StridedMemoryView view = StridedMemoryView.__new__(cls)
        if itemsize is None and dtype is None:
            raise ValueError("Either itemsize or dtype must be specified")
        if itemsize is not None and dtype is not None and itemsize != dtype.itemsize:
            raise ValueError(
                f"itemsize ({itemsize}) does not match dtype.itemsize ({dtype.itemsize})"
            )
        # (itemsize is None XOR dtype is None) OR they are equal
        view_buffer_strided(
            view,
            buffer,
            _StridedLayout(shape=shape, strides=strides, itemsize=getattr(dtype, "itemsize", itemsize)),
            dtype,
            is_readonly,
        )
        return view

    def __dealloc__(self):
        if self.dl_tensor == NULL:
            return

        if cpython.PyCapsule_IsValid(
                self.metadata, DLPACK_VERSIONED_TENSOR_USED_NAME):
            data = cpython.PyCapsule_GetPointer(
                self.metadata, DLPACK_VERSIONED_TENSOR_USED_NAME)
            dlm_tensor_ver = <DLManagedTensorVersioned*>data
            dlm_tensor_ver.deleter(dlm_tensor_ver)
        elif cpython.PyCapsule_IsValid(
                self.metadata, DLPACK_TENSOR_USED_NAME):
            data = cpython.PyCapsule_GetPointer(
                self.metadata, DLPACK_TENSOR_USED_NAME)
            dlm_tensor = <DLManagedTensor*>data
            dlm_tensor.deleter(dlm_tensor)

    def view(
        self, layout : _StridedLayout | None = None, dtype : numpy.dtype | None = None
    ) -> StridedMemoryView:
        """
        Creates a new view with adjusted layout and dtype.
        Same as calling :meth:`from_buffer` with the current buffer.

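        Examples
        --------
        An illustrative sketch (``u8_view`` is assumed to be an existing
        :obj:`StridedMemoryView` over ``uint8`` data; the new dtype must have
        the same itemsize as the current layout):

        .. code-block:: python

            import numpy as np

            # Reinterpret the same bytes as int8 without copying.
            i8_view = u8_view.view(dtype=np.dtype(np.int8))
            assert i8_view.dtype == np.int8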
        """
        cdef StridedMemoryView view = StridedMemoryView.__new__(self.__class__)
        if layout is None and dtype is None:
            return self
        if layout is None:
            layout = self.get_layout()
        if dtype is None:
            dtype = self.get_dtype()
        view_buffer_strided(view, self.get_buffer(), layout, dtype, self.readonly)
        return view

    def as_tensor_map(
        self,
        box_dim=None,
        *,
        options=None,
        element_strides=None,
        data_type=None,
        interleave=None,
        swizzle=None,
        l2_promotion=None,
        oob_fill=None,
    ):
        """Create a tiled :obj:`TensorMapDescriptor` from this view.

        This is the public entry point for creating tiled tensor map
        descriptors in ``cuda.core``. Pass either ``box_dim`` and the
        individual keyword arguments directly, or provide bundled tiled
        options via ``options=``.
        """
        from cuda.core._tensor_map import TensorMapDescriptor

        kwargs = {}
        if options is not None:
            kwargs["options"] = options
        if element_strides is not None:
            kwargs["element_strides"] = element_strides
        if data_type is not None:
            kwargs["data_type"] = data_type
        if interleave is not None:
            kwargs["interleave"] = interleave
        if swizzle is not None:
            kwargs["swizzle"] = swizzle
        if l2_promotion is not None:
            kwargs["l2_promotion"] = l2_promotion
        if oob_fill is not None:
            kwargs["oob_fill"] = oob_fill
        return TensorMapDescriptor._from_tiled(self, box_dim, **kwargs)

    def copy_from(
        self, other : StridedMemoryView, stream : Stream,
        allocator = None,
        blocking : bool | None = None,
    ):
        """
        Copies the data from the other view into this view.

        The copy can be performed between the following memory spaces:
        host-to-device, device-to-host, device-to-device (on the same device).

        Parameters
        ----------
        other : StridedMemoryView
            The view to copy data from.
        stream : Stream | None, optional
            The stream to schedule the copy on.
        allocator : MemoryResource | None, optional
            If temporary buffers are needed, the specified memory resources
            will be used to allocate the memory. If not specified, default
            resources will be used.
        blocking : bool | None, optional
            Whether the call should block until the copy is complete.
            * ``True``: the ``stream`` is synchronized with the host at the end of the call,
              blocking until the copy is complete.
            * ``False``: if possible, the call returns immediately once the copy is scheduled.
              However, in some cases of host-to-device or device-to-host copies, the call may
              still synchronize with the host if necessary.
            * ``None`` (default):
              * for device-to-device, it defaults to ``False`` (non-blocking),
              * for host-to-device or device-to-host, it defaults to ``True`` (blocking).
        """
        raise NotImplementedError("Sorry, not supported: copy_from")

    def copy_to(
        self, other : StridedMemoryView, stream : Stream | None = None,
        allocator = None,
        blocking : bool | None = None,
    ):
        """
        Copies the data from this view into the ``other`` view.

        For details, see :meth:`copy_from`.
        """
        raise NotImplementedError("Sorry, not supported: copy_to")

    def __dlpack__(
        self,
        *,
        stream: int | None = None,
        max_version: tuple[int, int] | None = None,
        dl_device: tuple[int, int] | None = None,
        copy: bool | None = None,
    ):
        # Similar to Buffer.__dlpack__: no implicit synchronization is performed.
        if dl_device is not None:
            raise BufferError("Sorry, not supported: dl_device other than None")
        if copy is True:
            raise BufferError("Sorry, not supported: copy=True")

        cdef bint versioned
        if max_version is None:
            versioned = False
        else:
            if not isinstance(max_version, tuple) or len(max_version) != 2:
                raise BufferError(f"Expected max_version tuple[int, int], got {max_version}")
            versioned = max_version >= (1, 0)

        # NOTE: stream is accepted for protocol compatibility but not used.
        cdef object capsule = _smv_make_py_capsule(self, versioned)
        return capsule

    def __dlpack_device__(self) -> tuple[int, int]:
        cdef _DLDeviceType device_type
        cdef int32_t device_id
        _smv_get_dl_device(self, &device_type, &device_id)
        return (<int>device_type, int(device_id))

    @property
    def _layout(self) -> _StridedLayout:
        """
        The layout of the tensor. For StridedMemoryView created from DLPack or CAI,
        the layout is inferred from the tensor object's metadata.
        """
        return self.get_layout()

    @property
    def size(self) -> int:
        return self.get_layout().get_volume()

    @property
    def shape(self) -> tuple[int, ...]:
        """
        Shape of the tensor.
        """
        return self.get_layout().get_shape_tuple()

    @property
    def strides(self) -> tuple[int, ...] | None:
        """
        Strides of the tensor (in **counts**, not bytes).
        """
        return self.get_layout().get_strides_tuple()

    @property
    def dtype(self) -> numpy.dtype | None:
        """
        Data type of the tensor.

        Supports standard NumPy dtypes as well as narrow data types (e.g., ``bfloat16``)
        when the optional `ml_dtypes <https://github.com/jax-ml/ml_dtypes>`_ package is
        installed. If ``ml_dtypes`` is not available and such a tensor is encountered,
        a :obj:`NotImplementedError` will be raised.
        """
        return self.get_dtype()

    def __repr__(self):
        return (f"StridedMemoryView(ptr={self.ptr},\n"
                + f"                  shape={self.shape},\n"
                + f"                  strides={self.strides},\n"
                + f"                  itemsize={self._layout.itemsize},\n"
                + f"                  dtype={get_simple_repr(self.dtype)},\n"
                + f"                  device_id={self.device_id},\n"
                + f"                  is_device_accessible={self.is_device_accessible},\n"
                + f"                  readonly={self.readonly},\n"
                + f"                  exporting_obj={get_simple_repr(self.exporting_obj)})")

    cdef inline _StridedLayout get_layout(self):
        if self._layout is None:
            if self.dl_tensor:
                self._layout = layout_from_dlpack(self.dl_tensor)
            elif self.metadata is not None:
                self._layout = layout_from_cai(self.metadata)
            else:
                raise ValueError("Cannot infer layout from the exporting object")
        return self._layout

    cdef inline object get_buffer(self):
        """
        Returns a Buffer instance with the underlying data.
        If the SMV was created from a Buffer, it will return the same Buffer instance.
        Otherwise, it will create a new instance with owner set to the exporting object.
        """
        if self._buffer is None:
            if isinstance(self.exporting_obj, Buffer):
                self._buffer = self.exporting_obj
            else:
                self._buffer = Buffer.from_handle(self.ptr, 0, owner=self.exporting_obj)
        return self._buffer

    cdef inline object get_dtype(self):
        if self._dtype is None:
            if self.dl_tensor != NULL:
                self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
            elif isinstance(self.metadata, int):
                # AOTI dtype code stored by the torch tensor bridge
                self._dtype = _get_tensor_bridge().resolve_aoti_dtype(
                    self.metadata)
            elif self.metadata is not None:
                self._dtype = _typestr2dtype(self.metadata["typestr"])
        return self._dtype


cdef void _smv_pycapsule_deleter(object capsule) noexcept:
    cdef DLManagedTensor* dlm_tensor
    cdef DLManagedTensorVersioned* dlm_tensor_ver
    # Do not invoke the deleter on a used capsule.
    if cpython.PyCapsule_IsValid(capsule, DLPACK_TENSOR_UNUSED_NAME):
        dlm_tensor = <DLManagedTensor*>(
            cpython.PyCapsule_GetPointer(capsule, DLPACK_TENSOR_UNUSED_NAME)
        )
        if dlm_tensor.deleter:
            dlm_tensor.deleter(dlm_tensor)
    elif cpython.PyCapsule_IsValid(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        dlm_tensor_ver = <DLManagedTensorVersioned*>(
            cpython.PyCapsule_GetPointer(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        )
        if dlm_tensor_ver.deleter:
            dlm_tensor_ver.deleter(dlm_tensor_ver)


cdef inline void _smv_release_export_resources(void* manager_ctx, int64_t* shape_ptr) noexcept with gil:
    if shape_ptr:
        stdlib.free(shape_ptr)
    if manager_ctx:
        cpython.Py_DECREF(<object>manager_ctx)


cdef void _smv_deleter(DLManagedTensor* tensor) noexcept with gil:
    if tensor:
        _smv_release_export_resources(tensor.manager_ctx, tensor.dl_tensor.shape)
        tensor.manager_ctx = NULL
        stdlib.free(tensor)


cdef void _smv_versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
    if tensor:
        _smv_release_export_resources(tensor.manager_ctx, tensor.dl_tensor.shape)
        tensor.manager_ctx = NULL
        stdlib.free(tensor)


cdef inline DLManagedTensorVersioned* _smv_allocate_dlm_tensor_versioned() except? NULL:
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    dlm_tensor_ver = <DLManagedTensorVersioned*>stdlib.malloc(sizeof(DLManagedTensorVersioned))
    if dlm_tensor_ver == NULL:
        raise MemoryError()
    dlm_tensor_ver.dl_tensor.shape = NULL
    dlm_tensor_ver.manager_ctx = NULL
    return dlm_tensor_ver


cdef inline DLManagedTensor* _smv_allocate_dlm_tensor() except? NULL:
    cdef DLManagedTensor* dlm_tensor = NULL
    dlm_tensor = <DLManagedTensor*>stdlib.malloc(sizeof(DLManagedTensor))
    if dlm_tensor == NULL:
        raise MemoryError()
    dlm_tensor.dl_tensor.shape = NULL
    dlm_tensor.manager_ctx = NULL
    return dlm_tensor


cdef inline int _smv_dtype_numpy_to_dlpack(object dtype_obj, DLDataType* out_dtype) except -1:
    cdef object np_dtype = numpy.dtype(dtype_obj)
    if np_dtype.fields is not None:
        raise BufferError("Structured dtypes are not supported for DLPack export")
    if not np_dtype.isnative and np_dtype.byteorder not in ("=", "|"):
        raise BufferError("Non-native-endian dtypes are not supported for DLPack export")

    cdef str kind = np_dtype.kind
    cdef int bits = np_dtype.itemsize * 8
    cdef uint8_t code
    if kind == "b":
        if bits != 8:
            raise BufferError(f"Unsupported bool dtype itemsize: {np_dtype.itemsize}")
        code = <uint8_t>kDLBool
    elif kind == "i":
        if bits not in (8, 16, 32, 64):
            raise BufferError(f"Unsupported signed integer dtype: {np_dtype}")
        code = <uint8_t>kDLInt
    elif kind == "u":
        if bits not in (8, 16, 32, 64):
            raise BufferError(f"Unsupported unsigned integer dtype: {np_dtype}")
        code = <uint8_t>kDLUInt
    elif kind == "f":
        if bits not in (16, 32, 64):
            raise BufferError(f"Unsupported floating dtype: {np_dtype}")
        code = <uint8_t>kDLFloat
    elif kind == "c":
        if bits not in (64, 128):
            raise BufferError(f"Unsupported complex dtype: {np_dtype}")
        code = <uint8_t>kDLComplex
    else:
        raise BufferError(f"Unsupported dtype for DLPack export: {np_dtype}")

    out_dtype.code = code
    out_dtype.bits = <uint8_t>bits
    out_dtype.lanes = <uint16_t>1
    return 0


cdef inline int _smv_get_dl_device(
    StridedMemoryView view,
    _DLDeviceType* out_device_type,
    int32_t* out_device_id,
) except -1:
    cdef _DLDeviceType device_type
    cdef int32_t device_id
    cdef object buf
    if view.dl_tensor != NULL:
        device_type = view.dl_tensor.device.device_type
        if device_type == _kDLCUDA:
            device_id = view.dl_tensor.device.device_id
        else:
            # CPU, CUDAHost, and CUDAManaged use device_id=0 in DLPack.
            device_id = 0
    elif view.is_device_accessible:
        buf = view.get_buffer()
        dev_type, dev_id = classify_dl_device(buf)
        device_type = <_DLDeviceType>dev_type
        device_id = <int32_t>dev_id
    else:
        device_type = _kDLCPU
        device_id = 0

    out_device_type[0] = device_type
    out_device_id[0] = device_id
    return 0


cdef inline int _smv_setup_dl_tensor_common(
    DLTensor* dl_tensor,
    StridedMemoryView view,
    _StridedLayout layout,
) except -1:
    cdef object dtype_obj = view.get_dtype()
    if dtype_obj is None:
        raise BufferError(
            "Cannot export StridedMemoryView via DLPack without dtype information; "
            "create the view with dtype specified."
        )
    _smv_dtype_numpy_to_dlpack(dtype_obj, &dl_tensor.dtype)
    _smv_get_dl_device(view, &dl_tensor.device.device_type, &dl_tensor.device.device_id)

    cdef int ndim = layout.base.ndim
    dl_tensor.ndim = ndim
    if layout.get_volume() == 0:
        dl_tensor.data = NULL
    else:
        dl_tensor.data = <void*><intptr_t>view.ptr
    dl_tensor.byte_offset = 0
    return 0


cdef inline int _smv_setup_dl_tensor(DLTensor* dl_tensor, StridedMemoryView view) except -1:
    cdef _StridedLayout layout = view.get_layout()
    _smv_setup_dl_tensor_common(dl_tensor, view, layout)

    cdef int i
    cdef int64_t* shape_strides = NULL
    cdef int64_t* strides_src = NULL
    cdef int ndim = dl_tensor.ndim
    if ndim == 0:
        dl_tensor.shape = NULL
        dl_tensor.strides = NULL
    else:
        # DLPack v1.2+ requires non-NULL strides for ndim != 0.
        shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * 2 * ndim)
        if shape_strides == NULL:
            raise MemoryError()
        try:
            strides_src = get_strides_ptr(layout.base)
            for i in range(ndim):
                shape_strides[i] = layout.base.shape[i]
                shape_strides[i + ndim] = strides_src[i]
        except Exception:
            stdlib.free(shape_strides)
            raise
        dl_tensor.shape = shape_strides
        dl_tensor.strides = shape_strides + ndim
    return 0


cdef inline int _smv_setup_dltensor_borrowed(DLTensor* dl_tensor, StridedMemoryView view) except -1:
    cdef _StridedLayout layout = view.get_layout()
    _smv_setup_dl_tensor_common(dl_tensor, view, layout)

    if dl_tensor.ndim == 0:
        dl_tensor.shape = NULL
        dl_tensor.strides = NULL
    else:
        dl_tensor.shape = layout.base.shape
        # For temporary/non-owning exchange we provide explicit strides.
        dl_tensor.strides = get_strides_ptr(layout.base)
    return 0


cdef inline int _smv_fill_managed_tensor_versioned(
    DLManagedTensorVersioned* dlm_tensor_ver,
    StridedMemoryView view,
) except -1:
    cpython.Py_INCREF(view)
    dlm_tensor_ver.manager_ctx = <void*>view
    dlm_tensor_ver.deleter = _smv_versioned_deleter
    dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
    dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
    dlm_tensor_ver.flags = DLPACK_FLAG_BITMASK_READ_ONLY if view.readonly else 0
    _smv_setup_dl_tensor(&dlm_tensor_ver.dl_tensor, view)
    return 0


cdef inline int _smv_fill_managed_tensor(
    DLManagedTensor* dlm_tensor,
    StridedMemoryView view,
) except -1:
    cpython.Py_INCREF(view)
    dlm_tensor.manager_ctx = <void*>view
    dlm_tensor.deleter = _smv_deleter
    _smv_setup_dl_tensor(&dlm_tensor.dl_tensor, view)
    return 0


cdef object _smv_make_py_capsule(StridedMemoryView view, bint versioned):
    cdef DLManagedTensor* dlm_tensor = NULL
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    cdef object capsule = None
    cdef void* tensor_ptr = NULL
    cdef const char* capsule_name
    try:
        if versioned:
            dlm_tensor_ver = _smv_allocate_dlm_tensor_versioned()
            _smv_fill_managed_tensor_versioned(dlm_tensor_ver, view)
            tensor_ptr = <void*>dlm_tensor_ver
            capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
        else:
            dlm_tensor = _smv_allocate_dlm_tensor()
            _smv_fill_managed_tensor(dlm_tensor, view)
            tensor_ptr = <void*>dlm_tensor
            capsule_name = DLPACK_TENSOR_UNUSED_NAME
        capsule = cpython.PyCapsule_New(tensor_ptr, capsule_name, _smv_pycapsule_deleter)
    except Exception:
        if capsule is None:
            _smv_deleter(dlm_tensor)
            _smv_versioned_deleter(dlm_tensor_ver)
        raise
    return capsule


cdef inline StridedMemoryView _smv_from_dlpack_capsule(object capsule, object exporting_obj):
    cdef void* data = NULL
    cdef DLTensor* dl_tensor = NULL
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    cdef DLManagedTensor* dlm_tensor = NULL
    cdef bint is_readonly = False
    cdef const char* used_name = NULL
    if cpython.PyCapsule_IsValid(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        dlm_tensor_ver = <DLManagedTensorVersioned*>data
        dl_tensor = &dlm_tensor_ver.dl_tensor
        is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0)
        used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
    elif cpython.PyCapsule_IsValid(capsule, DLPACK_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(capsule, DLPACK_TENSOR_UNUSED_NAME)
        dlm_tensor = <DLManagedTensor*>data
        dl_tensor = &dlm_tensor.dl_tensor
        is_readonly = False
        used_name = DLPACK_TENSOR_USED_NAME
    else:
        raise BufferError("Invalid DLPack capsule")

    cpython.PyCapsule_SetName(capsule, used_name)

    cdef StridedMemoryView view = StridedMemoryView.__new__(StridedMemoryView)
    view.dl_tensor = dl_tensor
    view.metadata = capsule
    view.ptr = <intptr_t>(dl_tensor.data) + <intptr_t>(dl_tensor.byte_offset)
    view.readonly = is_readonly
    view.exporting_obj = exporting_obj
    if dl_tensor.device.device_type == _kDLCPU:
        view.device_id = -1
        view.is_device_accessible = False
    elif dl_tensor.device.device_type in (_kDLCUDA, _kDLCUDAHost, _kDLCUDAManaged):
        view.device_id = dl_tensor.device.device_id
        view.is_device_accessible = True
    else:
        raise BufferError("device not supported")
    return view


cdef int _smv_managed_tensor_allocator(
    DLTensor* prototype,
    DLManagedTensorVersioned** out,
    void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message) noexcept,
) noexcept with gil:
    if out != NULL:
        out[0] = NULL
    if SetError != NULL:
        SetError(error_ctx, b"NotImplementedError", b"managed_tensor_allocator is not supported by StridedMemoryView")
    cpython.PyErr_SetString(NotImplementedError, b"managed_tensor_allocator is not supported by StridedMemoryView")
    return -1


cdef int _smv_managed_tensor_from_py_object_no_sync(
    void* py_object,
    DLManagedTensorVersioned** out,
) noexcept with gil:
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    if out == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out cannot be NULL")
        return -1
    out[0] = NULL
    cdef object obj = <object>py_object
    if not isinstance(obj, StridedMemoryView):
        cpython.PyErr_SetString(TypeError, b"py_object must be a StridedMemoryView")
        return -1
    try:
        dlm_tensor_ver = _smv_allocate_dlm_tensor_versioned()
        _smv_fill_managed_tensor_versioned(dlm_tensor_ver, <StridedMemoryView>obj)
    except Exception:
        _smv_versioned_deleter(dlm_tensor_ver)
        return -1
    out[0] = dlm_tensor_ver
    return 0


cdef int _smv_managed_tensor_to_py_object_no_sync(
    DLManagedTensorVersioned* tensor,
    void** out_py_object,
) noexcept with gil:
    cdef object capsule
    cdef object py_view
    if out_py_object == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out_py_object cannot be NULL")
        return -1
    out_py_object[0] = NULL
    if tensor == NULL:
        cpython.PyErr_SetString(RuntimeError, b"tensor cannot be NULL")
        return -1
    try:
        capsule = cpython.PyCapsule_New(
            <void*>tensor,
            DLPACK_VERSIONED_TENSOR_UNUSED_NAME,
            _smv_pycapsule_deleter,
        )
        py_view = _smv_from_dlpack_capsule(capsule, capsule)
        cpython.Py_INCREF(py_view)
        out_py_object[0] = <void*>py_view
    except Exception:
        return -1
    return 0


cdef int _smv_dltensor_from_py_object_no_sync(
    void* py_object,
    DLTensor* out,
) noexcept with gil:
    if out == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out cannot be NULL")
        return -1
    cdef object obj = <object>py_object
    if not isinstance(obj, StridedMemoryView):
        cpython.PyErr_SetString(TypeError, b"py_object must be a StridedMemoryView")
        return -1
    try:
        _smv_setup_dltensor_borrowed(out, <StridedMemoryView>obj)
    except Exception:
        return -1
    return 0


cdef int _smv_current_work_stream(
    _DLDeviceType device_type,
    int32_t device_id,
    void** out_current_stream,
) noexcept with gil:
    if out_current_stream == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out_current_stream cannot be NULL")
        return -1
    # cuda.core has no global/current stream state today.
    out_current_stream[0] = NULL
    return 0


cdef void _init_smv_dlpack_exchange_api():
    global _SMV_DLPACK_EXCHANGE_API_INITED
    if _SMV_DLPACK_EXCHANGE_API_INITED:
        return
    _SMV_DLPACK_EXCHANGE_API.header.version.major = DLPACK_MAJOR_VERSION
    _SMV_DLPACK_EXCHANGE_API.header.version.minor = DLPACK_MINOR_VERSION
    _SMV_DLPACK_EXCHANGE_API.header.prev_api = NULL
    _SMV_DLPACK_EXCHANGE_API.managed_tensor_allocator = _smv_managed_tensor_allocator
    _SMV_DLPACK_EXCHANGE_API.managed_tensor_from_py_object_no_sync = _smv_managed_tensor_from_py_object_no_sync
    _SMV_DLPACK_EXCHANGE_API.managed_tensor_to_py_object_no_sync = _smv_managed_tensor_to_py_object_no_sync
    _SMV_DLPACK_EXCHANGE_API.dltensor_from_py_object_no_sync = _smv_dltensor_from_py_object_no_sync
    _SMV_DLPACK_EXCHANGE_API.current_work_stream = _smv_current_work_stream
    _SMV_DLPACK_EXCHANGE_API_INITED = True


_init_smv_dlpack_exchange_api()
# cdef classes are immutable types in Cython 3, so inject these attributes
# directly into the type dict.
(<dict>(<PyTypeObject*>StridedMemoryView).tp_dict)["__dlpack_c_exchange_api__"] = _SMV_DLPACK_EXCHANGE_API_CAPSULE
(<dict>(<PyTypeObject*>StridedMemoryView).tp_dict)["__c_dlpack_exchange_api__"] = _SMV_DLPACK_EXCHANGE_API_CAPSULE
PyType_Modified(<PyTypeObject*>StridedMemoryView)


cdef str get_simple_repr(obj):
    # TODO: better handling in np.dtype objects
    cdef object obj_class
    cdef str obj_repr
    if isinstance(obj, type):
        obj_class = obj
    else:
        obj_class = obj.__class__
    if obj_class.__module__ in (None, "builtins"):
        obj_repr = obj_class.__name__
    else:
        obj_repr = f"{obj_class.__module__}.{obj_class.__name__}"
    return obj_repr


cdef bint check_has_dlpack(obj) except*:
    cdef bint has_dlpack
    if hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"):
        has_dlpack = True
    elif hasattr(obj, "__cuda_array_interface__"):
        has_dlpack = False
    else:
        raise RuntimeError(
            "the input object does not support any data exchange protocol")
    return has_dlpack


cdef class _StridedMemoryViewProxy:
    cdef readonly:
        object obj
        bint has_dlpack

    def __init__(self, obj):
        self.obj = obj
        self.has_dlpack = check_has_dlpack(obj)

    cpdef StridedMemoryView view(self, stream_ptr=None):
        if self.has_dlpack:
            return StridedMemoryView.from_dlpack(self.obj, stream_ptr)
        else:
            return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr)


cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
    cdef int dldevice, device_id
    cdef bint is_device_accessible, is_readonly
    is_device_accessible = False
    dldevice, device_id = obj.__dlpack_device__()
    if dldevice == _kDLCPU:
        assert device_id == 0
        device_id = -1
        if stream_ptr is None:
            raise BufferError("stream=None is ambiguous with view()")
        elif stream_ptr == -1:
            stream_ptr = None
    elif dldevice == _kDLCUDA:
        assert device_id >= 0
        is_device_accessible = True
        # no need to check other stream values, it's a pass-through
        if stream_ptr is None:
            raise BufferError("stream=None is ambiguous with view()")
    elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
        is_device_accessible = True
        # just do a pass-through without any checks, as pinned/managed memory can be
        # accessed on both host and device
    else:
        raise BufferError("device not supported")

    cdef object capsule
    try:
        capsule = obj.__dlpack__(
            stream=int(stream_ptr) if stream_ptr else None,
            max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
    except TypeError:
        capsule = obj.__dlpack__(
            stream=int(stream_ptr) if stream_ptr else None)

    cdef void* data = NULL
    cdef DLTensor* dl_tensor
    cdef DLManagedTensorVersioned* dlm_tensor_ver
    cdef DLManagedTensor* dlm_tensor
    cdef const char *used_name
    if cpython.PyCapsule_IsValid(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        dlm_tensor_ver = <DLManagedTensorVersioned*>data
        dl_tensor = &dlm_tensor_ver.dl_tensor
        is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0)
        used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
    elif cpython.PyCapsule_IsValid(
            capsule, DLPACK_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_TENSOR_UNUSED_NAME)
        dlm_tensor = <DLManagedTensor*>data
        dl_tensor = &dlm_tensor.dl_tensor
        is_readonly = False
        used_name = DLPACK_TENSOR_USED_NAME
    else:
        assert False

    cpython.PyCapsule_SetName(capsule, used_name)

    cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
    buf.dl_tensor = dl_tensor
    buf.metadata = capsule
    buf.ptr = <intptr_t>(dl_tensor.data)
    buf.device_id = device_id
    buf.is_device_accessible = is_device_accessible
    buf.readonly = is_readonly
    buf.exporting_obj = obj

    return buf


@functools.lru_cache
def _typestr2dtype(str typestr):
    return numpy.dtype(typestr)


@functools.lru_cache
def _typestr2itemsize(str typestr):
    return _typestr2dtype(typestr).itemsize


cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
    cdef int bits = dtype.bits
    if dtype.lanes != 1:
        # TODO: return a NumPy structured dtype?
        raise NotImplementedError(
            f'vector dtypes (lanes={dtype.lanes}) are not supported')
    if dtype.code == kDLUInt:
        if bits == 8:
            np_dtype = numpy.uint8
        elif bits == 16:
            np_dtype = numpy.uint16
        elif bits == 32:
            np_dtype = numpy.uint32
        elif bits == 64:
            np_dtype = numpy.uint64
        else:
            raise TypeError('uint{} is not supported.'.format(bits))
    elif dtype.code == kDLInt:
        if bits == 8:
            np_dtype = numpy.int8
        elif bits == 16:
            np_dtype = numpy.int16
        elif bits == 32:
            np_dtype = numpy.int32
        elif bits == 64:
            np_dtype = numpy.int64
        else:
            raise TypeError('int{} is not supported.'.format(bits))
    elif dtype.code == kDLFloat:
        if bits == 16:
            np_dtype = numpy.float16
        elif bits == 32:
            np_dtype = numpy.float32
        elif bits == 64:
            np_dtype = numpy.float64
        else:
            raise TypeError('float{} is not supported.'.format(bits))
    elif dtype.code == kDLComplex:
        # TODO(leofang): support complex32
        if bits == 64:
            np_dtype = numpy.complex64
        elif bits == 128:
            np_dtype = numpy.complex128
        else:
            raise TypeError('complex{} is not supported.'.format(bits))
    elif dtype.code == kDLBool:
        if bits == 8:
            np_dtype = numpy.bool_
        else:
            raise TypeError(f'{bits}-bit bool is not supported')
    elif dtype.code == kDLBfloat:
        if bfloat16 is not None:
            np_dtype = numpy.dtype("bfloat16")
        else:
            raise NotImplementedError(
                'Support for bfloat16 within cuda-core requires `ml_dtypes` '
                'to be installed.'
            )
    else:
        raise TypeError('Unsupported dtype. dtype code: {}'.format(dtype.code))

    # We want the dtype object not just the type object
    return numpy.dtype(np_dtype)


cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
    cdef dict cai_data = obj.__cuda_array_interface__
    if cai_data["version"] < 3:
        raise BufferError("only CUDA Array Interface v3 or above is supported")
    if cai_data.get("mask") is not None:
        raise BufferError("mask is not supported")
    if stream_ptr is None:
        raise BufferError("stream=None is ambiguous with view()")

    cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
    buf.exporting_obj = obj
    buf.metadata = cai_data
    buf.dl_tensor = NULL
    # Validate shape/strides/typestr eagerly so constructor paths fail fast.
    buf.get_layout()
    buf.ptr, buf.readonly = cai_data["data"]
    buf.is_device_accessible = True
    if buf.ptr != 0:
        buf.device_id = handle_return(
            driver.cuPointerGetAttribute(
                driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                buf.ptr))
    else:
        buf.device_id = handle_return(driver.cuCtxGetDevice())

    cdef intptr_t producer_s, consumer_s
    cdef EventHandle h_event
    stream_ptr = int(stream_ptr)
    if stream_ptr != -1:
        stream = cai_data.get("stream")
        if stream is not None:
            producer_s = <intptr_t>(stream)
            consumer_s = <intptr_t>(stream_ptr)
            assert producer_s > 0
            # establish stream order
            if producer_s != consumer_s:
                with nogil:
                    h_event = create_event_handle_noctx(cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING)
                    HANDLE_RETURN(cydriver.cuEventRecord(
                        as_cu(h_event), <cydriver.CUstream>producer_s))
                    HANDLE_RETURN(cydriver.cuStreamWaitEvent(
                        <cydriver.CUstream>consumer_s, as_cu(h_event), 0))
        elif _is_torch_tensor(obj):
            # PyTorch's __cuda_array_interface__ reports version 2 and
            # omits the "stream" field, so the standard CAI sync path
            # above is a no-op for torch tensors. This is unsafe: the
            # consumer has no guarantee that the producer's work is
            # visible. We fix this by querying PyTorch's current CUDA
            # stream via the AOTI stable C ABI and performing the same
            # event-based stream ordering.
            _get_tensor_bridge().sync_torch_stream(
                buf.device_id, <intptr_t>(stream_ptr))

    return buf


cpdef StridedMemoryView view_as_array_interface(obj, view=None):
    cdef dict data = obj.__array_interface__
    if data["version"] < 3:
        raise BufferError("only NumPy Array Interface v3 or above is supported")
    if data.get("mask") is not None:
        raise BufferError("mask is not supported")

    cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
    buf.exporting_obj = obj
    buf.metadata = data
    buf.dl_tensor = NULL
    # Validate shape/strides/typestr eagerly so constructor paths fail fast.
    buf.get_layout()
    buf.ptr, buf.readonly = data["data"]
    buf.is_device_accessible = False
    buf.device_id = handle_return(driver.cuCtxGetDevice())
    return buf


def args_viewable_as_strided_memory(tuple arg_indices):
    """
    Decorator to create proxy objects to :obj:`StridedMemoryView` for the
    specified positional arguments.

    This allows array/tensor attributes to be accessed inside the function
    implementation, while keeping the function body array-library-agnostic (if
    desired).

    Inside the decorated function, the specified arguments become instances
    of an (undocumented) proxy type, regardless of their original source. A
    :obj:`StridedMemoryView` instance can be obtained by passing the (consumer)
    stream pointer (as a Python `int`) to the proxy's ``view()`` method. For
    example:

    .. code-block:: python

        @args_viewable_as_strided_memory((1,))
        def my_func(arg0, arg1, arg2, stream: Stream):
            # arg1 can be any object supporting DLPack or CUDA Array Interface
            view = arg1.view(stream.handle)
            assert isinstance(view, StridedMemoryView)
            ...

    Parameters
    ----------
    arg_indices : tuple
        The indices of the target positional arguments.
    """
    def wrapped_func_with_indices(func):
        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            args = list(args)
            cdef int idx
            for idx in arg_indices:
                args[idx] = _StridedMemoryViewProxy(args[idx])
            return func(*args, **kwargs)
        return wrapped_func
    return wrapped_func_with_indices


cdef inline _StridedLayout layout_from_dlpack(DLTensor* dl_tensor):
    cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout)
    cdef int nbits = dl_tensor.dtype.bits * dl_tensor.dtype.lanes
    cdef int itemsize = nbits >> 3
    if (itemsize << 3) != nbits:
        raise ValueError("dl_tensor.dtype.bits must be a multiple of 8")
    layout.init_from_ptr(dl_tensor.ndim, dl_tensor.shape, dl_tensor.strides, itemsize)
    return layout


cdef _StridedLayout layout_from_cai(object metadata):
    cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout)
    cdef object shape = metadata["shape"]
    cdef object strides = metadata.get("strides")
    cdef int itemsize = _typestr2itemsize(metadata["typestr"])
    layout.init_from_tuple(shape, strides, itemsize, True)
    return layout


cdef inline intptr_t get_data_ptr(object buffer, _StridedLayout layout) except? 0:
    return <intptr_t>(int(buffer.handle)) + layout.get_slice_offset_in_bytes()


cdef inline int view_buffer_strided(
    StridedMemoryView view,
    object buffer,
    _StridedLayout layout,
    object dtype,
    bint is_readonly,
) except -1:
    if dtype is not None:
        dtype = numpy.dtype(dtype)
        if dtype.itemsize != layout.itemsize:
            raise ValueError(
                f"The dtype's itemsize ({dtype.itemsize}) does not match the layout's "
                f"itemsize ({layout.itemsize})."
            )
    # Check that the layout's offset range [min_offset, max_offset] fits
    # within the [0, buffer.size - 1] range.
    # The required_size_in_bytes check fails if min_offset < 0.
    # NB. For external memory, both positive and negative offsets can be valid,
    # but for a proper check we'd need to know both size and data offset,
    # while neither is reported by the packages.
    cdef bint is_allocated = buffer.memory_resource is not None
    if is_allocated and buffer.size < layout.get_required_size_in_bytes():
        raise ValueError(
            f"Buffer size is too small for the layout. "
            f"Expected at least {layout.get_required_size_in_bytes()} bytes, "
            f"got {buffer.size} bytes."
        )
    # set the public attributes
    view.ptr = get_data_ptr(buffer, layout)
    view.device_id = buffer.device_id
    view.is_device_accessible = buffer.is_device_accessible
    view.readonly = is_readonly
    view.exporting_obj = view._buffer = buffer
    # no dlpack/cai metadata
    view.dl_tensor = NULL
    view.metadata = None
    # we get the layout from the caller
    view._layout = layout
    view._dtype = dtype
    return 0