Coverage for cuda/core/_memoryview.pyx: 59.97%
637 statements
« prev ^ index » next — coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7from ._dlpack cimport *
8from libc.stdint cimport intptr_t
9from cuda.core._layout cimport _StridedLayout, get_strides_ptr
10from cuda.core._stream import Stream
12import functools
13import warnings
15import numpy
17from cuda.bindings cimport cydriver
18from cuda.core._resource_handles cimport (
19 EventHandle,
20 create_event_handle_noctx,
21 as_cu,
22)
24from cuda.core._utils.cuda_utils import handle_return, driver
25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
28from cuda.core._memory import Buffer
# Optional dependency: ml_dtypes provides narrow dtypes (e.g. bfloat16) that
# NumPy does not ship. When unavailable, bfloat16 stays None and consumers
# that need it are expected to raise (see the `dtype` property docstring).
try:
    from ml_dtypes import bfloat16
except ImportError:
    bfloat16 = None
36# TODO(leofang): support NumPy structured dtypes
# Minimal CPython C-API declarations used by this module. Only the fields
# actually accessed are declared (tp_dict), plus PyType_Modified for
# invalidating type-attribute caches after mutating a type's dict.
cdef extern from "Python.h":
    ctypedef struct PyTypeObject:
        void* tp_dict
    void PyType_Modified(PyTypeObject*)
# Module-owned DLPack exchange API table. The struct is populated lazily by
# _init_smv_dlpack_exchange_api(); the INITED flag guards against repeated
# initialization. The capsule exposes a stable pointer to the (static) struct
# under the conventional name, with no destructor (NULL) since the struct has
# static storage duration.
cdef DLPackExchangeAPI _SMV_DLPACK_EXCHANGE_API
cdef bint _SMV_DLPACK_EXCHANGE_API_INITED = False
_SMV_DLPACK_EXCHANGE_API_CAPSULE = cpython.PyCapsule_New(
    <void*>&_SMV_DLPACK_EXCHANGE_API,
    b"dlpack_exchange_api",
    NULL,
)
cdef class StridedMemoryView:
    """A class holding metadata of a strided dense array/tensor.

    A :obj:`StridedMemoryView` instance can be created in three ways:

    1. Using the :obj:`args_viewable_as_strided_memory` decorator (recommended)
    2. Explicit construction relying on DLPack or CUDA Array Interface, see below.
    3. From :obj:`~_memory.Buffer` and shape and size tuples (see
       :meth:`from_buffer` classmethod)

    ``StridedMemoryView(obj, stream_ptr)`` can be used to create a view from
    objects supporting either DLPack (up to v1.0) or CUDA Array Interface
    (CAI) v3. When wrapping an arbitrary object it will try the DLPack protocol
    first, then the CAI protocol. A :obj:`BufferError` is raised if neither is
    supported.

    Since either way would take a consumer stream, for DLPack it is passed to
    ``obj.__dlpack__()`` as-is (except for :obj:`None`, see below); for CAI, a
    stream order will be established between the consumer stream and the
    producer stream (from ``obj.__cuda_array_interface__()["stream"]``), as if
    ``cudaStreamWaitEvent`` is called by this method.

    To opt-out of the stream ordering operation in either DLPack or CAI,
    please pass ``stream_ptr=-1``. Note that this deviates (on purpose)
    from the semantics of ``obj.__dlpack__(stream=None, ...)`` since ``cuda.core``
    does not encourage using the (legacy) default/null stream, but is
    consistent with the CAI's semantics. For DLPack, ``stream=-1`` will be
    internally passed to ``obj.__dlpack__()`` instead.

    Parameters
    ----------
    obj : Any
        Any objects that supports either DLPack (up to v1.0) or CUDA Array
        Interface (v3).
    stream_ptr: int
        The pointer address (as Python `int`) to the **consumer** stream.
        Stream ordering will be properly established unless ``-1`` is passed.

    Attributes
    ----------
    ptr : int
        Pointer to the tensor buffer (as a Python `int`).
    device_id : int
        The device ID for where the tensor is located. It is -1 for CPU tensors
        (meaning those only accessible from the host).
    is_device_accessible : bool
        Whether the tensor data can be accessed on the GPU.
    readonly: bool
        Whether the tensor data can be modified in place.
    exporting_obj : Any
        A reference to the original tensor object that is being viewed.
        If the view is created with :meth:`from_buffer`,
        it will be the Buffer instance passed to the method.
    """
    cdef readonly:
        # Public, read-only attributes documented in the class docstring above.
        intptr_t ptr
        int device_id
        bint is_device_accessible
        bint readonly
        object exporting_obj

    cdef:
        # If using dlpack, this is a strong reference to the result of
        # obj.__dlpack__() so we can lazily create shape and strides from
        # it later. If using CAI, this is a reference to the source
        # `__cuda_array_interface__` object.
        object metadata

        # The imported DLPack tensor if obj supports __dlpack__; otherwise
        # must be NULL.
        DLTensor *dl_tensor

        # Memoized properties
        # Either lazily inferred from dl_tensor/metadata,
        # or explicitly provided if created with from_buffer().
        _StridedLayout _layout
        # Either exporting_obj if it is a Buffer, otherwise a Buffer instance
        # with owner set to the exporting object.
        object _buffer
        # Either lazily inferred from dl_tensor/metadata,
        # or explicitly provided if created with from_buffer().
        # In the latter case, it can be None.
        object _dtype

    def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
        # Direct construction is deprecated in all forms; each path emits a
        # DeprecationWarning pointing at the appropriate classmethod before
        # delegating to the same helper the classmethod would use.
        cdef str clsname = self.__class__.__name__
        if obj is not None:
            # populate self's attributes
            if check_has_dlpack(obj):
                warnings.warn(
                    f"Constructing a {clsname} directly from a DLPack-supporting object is deprecated; "
                    "Use `StridedMemoryView.from_dlpack` or `StridedMemoryView.from_any_interface` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                view_as_dlpack(obj, stream_ptr, self)
            else:
                warnings.warn(
                    f"Constructing a {clsname} directly from a CUDA-array-interface-supporting object is deprecated; "
                    "Use `StridedMemoryView.from_cuda_array_interface` or `StridedMemoryView.from_any_interface` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                view_as_cai(obj, stream_ptr, self)
        else:
            warnings.warn(
                f"Constructing an empty {clsname} is deprecated; "
                "use one of the classmethods `from_dlpack`, `from_cuda_array_interface` or `from_any_interface` "
                "to construct a StridedMemoryView from an object",
                DeprecationWarning,
                stacklevel=2,
            )

    @classmethod
    def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
        """Create a view from an object supporting the `DLPack <https://dmlc.github.io/dlpack/latest/>`_ protocol.

        Parameters
        ----------
        obj : object
            An object implementing the `DLPack <https://dmlc.github.io/dlpack/latest/>`_ protocol
            (via ``__dlpack__``).
        stream_ptr : int, optional
            Stream pointer for synchronization. If ``None``, no synchronization is performed.
        """
        cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
        view_as_dlpack(obj, stream_ptr, buf)
        return buf

    @classmethod
    def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
        """Create a view from an object supporting the `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_ protocol.

        Parameters
        ----------
        obj : object
            An object implementing the `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_ protocol.
        stream_ptr : int, optional
            Stream pointer for synchronization. If ``None``, no synchronization is performed.
        """
        cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
        view_as_cai(obj, stream_ptr, buf)
        return buf

    @classmethod
    def from_array_interface(cls, obj: object) -> StridedMemoryView:
        """Create a view from an object supporting the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol.

        Parameters
        ----------
        obj : object
            An object implementing the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol (e.g., a numpy array).
        """
        cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
        view_as_array_interface(obj, buf)
        return buf

    @classmethod
    def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
        """Create a view by automatically selecting the best available protocol.

        Tries `DLPack <https://dmlc.github.io/dlpack/latest/>`_ first, then falls back to
        `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_.

        Parameters
        ----------
        obj : object
            An object implementing `DLPack <https://dmlc.github.io/dlpack/latest/>`_ or
            `__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_.
        stream_ptr : int, optional
            Stream pointer for synchronization. If ``None``, no synchronization is performed.
        """
        if check_has_dlpack(obj):
            return cls.from_dlpack(obj, stream_ptr)
        return cls.from_cuda_array_interface(obj, stream_ptr)

    @classmethod
    def from_buffer(
        cls,
        buffer : Buffer,
        shape : tuple[int, ...],
        strides : tuple[int, ...] | None = None,
        *,
        itemsize : int | None = None,
        dtype : numpy.dtype | None = None,
        is_readonly : bool = False
    ) -> StridedMemoryView:
        """
        Creates a :obj:`StridedMemoryView` instance from a :obj:`~_memory.Buffer` and shape and strides tuples.
        The Buffer can be either allocation coming from a :obj:`MemoryResource` or an external allocation
        wrapped in a :obj:`~_memory.Buffer` object with ``Buffer.from_handle(ptr, size, owner=...)``.

        .. caution::
            When creating a :obj:`StridedMemoryView` from a :obj:`~_memory.Buffer`,
            no synchronization is performed. It is the user's responsibility to ensure
            the data in ``buffer`` is properly synchronized when consuming the view.

        Parameters
        ----------
        buffer : :obj:`~_memory.Buffer`
            The buffer to create the view from.
        shape : :obj:`tuple`
            The number of elements along each dimension of the view.
        strides : :obj:`tuple`, optional
            The stride along each dimension, in counts (not bytes), matching
            the :attr:`strides` property.
        itemsize : int, optional
            The size of each element in bytes. Either ``itemsize`` or ``dtype``
            must be specified; if both are given they must agree.
        dtype : :obj:`numpy.dtype`
            Optional dtype.
            If specified, the dtype's itemsize must match the layout's itemsize.
        is_readonly : bool, optional
            Whether to mark the view as readonly.
        """
        cdef StridedMemoryView view = StridedMemoryView.__new__(cls)
        if itemsize is None and dtype is None:
            raise ValueError("Either itemsize or dtype must be specified")
        if itemsize is not None and dtype is not None and itemsize != dtype.itemsize:
            raise ValueError(
                f"itemsize ({itemsize}) does not match dtype.itemsize ({dtype.itemsize})"
            )
        # (itemsize is None XOR dtype is None) OR they are equal
        view_buffer_strided(
            view,
            buffer,
            _StridedLayout(shape=shape, strides=strides, itemsize=getattr(dtype, "itemsize", itemsize)),
            dtype,
            is_readonly,
        )
        return view

    def __dealloc__(self):
        # Nothing to release unless this view imported a DLPack tensor.
        if self.dl_tensor == NULL:
            return

        # On import, the capsule was renamed to the "used" name, marking this
        # view as the owner; invoke the producer's deleter exactly once here.
        if cpython.PyCapsule_IsValid(
                self.metadata, DLPACK_VERSIONED_TENSOR_USED_NAME):
            data = cpython.PyCapsule_GetPointer(
                self.metadata, DLPACK_VERSIONED_TENSOR_USED_NAME)
            dlm_tensor_ver = <DLManagedTensorVersioned*>data
            dlm_tensor_ver.deleter(dlm_tensor_ver)
        elif cpython.PyCapsule_IsValid(
                self.metadata, DLPACK_TENSOR_USED_NAME):
            data = cpython.PyCapsule_GetPointer(
                self.metadata, DLPACK_TENSOR_USED_NAME)
            dlm_tensor = <DLManagedTensor*>data
            dlm_tensor.deleter(dlm_tensor)

    def view(
        self, layout : _StridedLayout | None = None, dtype : numpy.dtype | None = None
    ) -> StridedMemoryView:
        """
        Creates a new view with adjusted layout and dtype.
        Same as calling :meth:`from_buffer` with the current buffer.

        If both arguments are None, ``self`` is returned unchanged (no copy).
        """
        cdef StridedMemoryView view = StridedMemoryView.__new__(self.__class__)
        if layout is None and dtype is None:
            return self
        if layout is None:
            layout = self.get_layout()
        if dtype is None:
            dtype = self.get_dtype()
        # NOTE(review): unlike from_buffer, no itemsize/dtype consistency
        # check is performed here — confirm view_buffer_strided validates.
        view_buffer_strided(view, self.get_buffer(), layout, dtype, self.readonly)
        return view

    def copy_from(
        self, other : StridedMemoryView, stream : Stream,
        allocator = None,
        blocking : bool | None = None,
    ):
        """
        Copies the data from the other view into this view.

        The copy can be performed between following memory spaces:
        host-to-device, device-to-host, device-to-device (on the same device).

        Parameters
        ----------
        other : StridedMemoryView
            The view to copy data from.
        stream : Stream | None, optional
            The stream to schedule the copy on.
        allocator : MemoryResource | None, optional
            If temporary buffers are needed, the specified memory resources
            will be used to allocate the memory. If not specified, default
            resources will be used.
        blocking : bool | None, optional
            Whether the call should block until the copy is complete.
            * ``True``: the ``stream`` is synchronized with the host at the end of the call,
              blocking until the copy is complete.
            * ``False``: if possible, the call returns immediately once the copy is scheduled.
              However, in some cases of host-to-device or device-to-host copies, the call may
              still synchronize with the host if necessary.
            * ``None`` (default):
              * for device-to-device, it defaults to ``False`` (non-blocking),
              * for host-to-device or device-to-host, it defaults to ``True`` (blocking).
        """
        # Not yet implemented; see docstring for the intended contract.
        raise NotImplementedError("Sorry, not supported: copy_from")

    def copy_to(
        self, other : StridedMemoryView, stream : Stream | None = None,
        allocator = None,
        blocking : bool | None = None,
    ):
        """
        Copies the data from this view into the ``other`` view.

        For details, see :meth:`copy_from`.
        """
        # Not yet implemented; see copy_from for the intended contract.
        raise NotImplementedError("Sorry, not supported: copy_to")

    def __dlpack__(
        self,
        *,
        stream: int | None = None,
        max_version: tuple[int, int] | None = None,
        dl_device: tuple[int, int] | None = None,
        copy: bool | None = None,
    ):
        """Export this view as a DLPack capsule.

        ``stream`` is accepted for protocol compatibility but ignored;
        ``dl_device`` (other than None) and ``copy=True`` are rejected.
        """
        # Similar to Buffer.__dlpack__: no implicit synchronization is performed.
        if dl_device is not None:
            raise BufferError("Sorry, not supported: dl_device other than None")
        if copy is True:
            raise BufferError("Sorry, not supported: copy=True")

        cdef bint versioned
        if max_version is None:
            # Legacy consumers (pre-1.0) get an unversioned DLManagedTensor.
            versioned = False
        else:
            if not isinstance(max_version, tuple) or len(max_version) != 2:
                raise BufferError(f"Expected max_version tuple[int, int], got {max_version}")
            versioned = max_version >= (1, 0)

        # NOTE: stream is accepted for protocol compatibility but not used.
        cdef object capsule = _smv_make_py_capsule(self, versioned)
        return capsule

    def __dlpack_device__(self) -> tuple[int, int]:
        """Return the DLPack ``(device_type, device_id)`` pair for this view."""
        cdef _DLDeviceType device_type
        cdef int32_t device_id
        _smv_get_dl_device(self, &device_type, &device_id)
        return (<int>device_type, int(device_id))

    @property
    def _layout(self) -> _StridedLayout:
        """
        The layout of the tensor. For StridedMemoryView created from DLPack or CAI,
        the layout is inferred from the tensor object's metadata.
        """
        return self.get_layout()

    @property
    def size(self) -> int:
        """Total number of elements in the tensor (product of the shape)."""
        return self.get_layout().get_volume()

    @property
    def shape(self) -> tuple[int, ...]:
        """
        Shape of the tensor.
        """
        return self.get_layout().get_shape_tuple()

    @property
    def strides(self) -> tuple[int, ...] | None:
        """
        Strides of the tensor (in **counts**, not bytes).
        """
        return self.get_layout().get_strides_tuple()

    @property
    def dtype(self) -> numpy.dtype | None:
        """
        Data type of the tensor.

        Supports standard NumPy dtypes as well as narrow data types (e.g., ``bfloat16``)
        when the optional `ml_dtypes <https://github.com/jax-ml/ml_dtypes>`_ package is
        installed. If ``ml_dtypes`` is not available and such a tensor is encountered,
        a :obj:`NotImplementedError` will be raised.
        """
        return self.get_dtype()

    def __repr__(self):
        return (f"StridedMemoryView(ptr={self.ptr},\n"
              + f"                  shape={self.shape},\n"
              + f"                  strides={self.strides},\n"
              + f"                  itemsize={self._layout.itemsize},\n"
              + f"                  dtype={get_simple_repr(self.dtype)},\n"
              + f"                  device_id={self.device_id},\n"
              + f"                  is_device_accessible={self.is_device_accessible},\n"
              + f"                  readonly={self.readonly},\n"
              + f"                  exporting_obj={get_simple_repr(self.exporting_obj)})")

    cdef inline _StridedLayout get_layout(self):
        # Lazily infer and memoize the layout from whichever source this view
        # was created with (DLPack tensor, CAI metadata, or from_buffer which
        # sets _layout explicitly).
        if self._layout is None:
            if self.dl_tensor:
                self._layout = layout_from_dlpack(self.dl_tensor)
            elif self.metadata is not None:
                self._layout = layout_from_cai(self.metadata)
            else:
                raise ValueError("Cannot infer layout from the exporting object")
        return self._layout

    cdef inline object get_buffer(self):
        """
        Returns Buffer instance with the underlying data.
        If the SMV was created from a Buffer, it will return the same Buffer instance.
        Otherwise, it will create a new instance with owner set to the exporting object.
        """
        if self._buffer is None:
            if isinstance(self.exporting_obj, Buffer):
                self._buffer = self.exporting_obj
            else:
                # NOTE(review): size is reported as 0 because the allocation
                # extent of a foreign object is unknown here — confirm that
                # consumers of this Buffer do not rely on its size.
                self._buffer = Buffer.from_handle(self.ptr, 0, owner=self.exporting_obj)
        return self._buffer

    cdef inline object get_dtype(self):
        # Lazily infer and memoize the dtype. May legitimately stay None when
        # created via from_buffer without a dtype.
        if self._dtype is None:
            if self.dl_tensor != NULL:
                self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
            elif self.metadata is not None:
                self._dtype = _typestr2dtype(self.metadata["typestr"])
        return self._dtype
cdef void _smv_pycapsule_deleter(object capsule) noexcept:
    # Capsule destructor for DLPack capsules produced by _smv_make_py_capsule.
    # Frees the managed tensor only if the capsule still carries an "unused"
    # name; a consumer that imported the tensor renames the capsule to the
    # "used" name, transferring ownership of the deleter call to itself.
    cdef DLManagedTensor* dlm_tensor
    cdef DLManagedTensorVersioned* dlm_tensor_ver
    # Do not invoke the deleter on a used capsule.
    if cpython.PyCapsule_IsValid(capsule, DLPACK_TENSOR_UNUSED_NAME):
        dlm_tensor = <DLManagedTensor*>(
            cpython.PyCapsule_GetPointer(capsule, DLPACK_TENSOR_UNUSED_NAME)
        )
        if dlm_tensor.deleter:
            dlm_tensor.deleter(dlm_tensor)
    elif cpython.PyCapsule_IsValid(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        dlm_tensor_ver = <DLManagedTensorVersioned*>(
            cpython.PyCapsule_GetPointer(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        )
        if dlm_tensor_ver.deleter:
            dlm_tensor_ver.deleter(dlm_tensor_ver)
cdef inline void _smv_release_export_resources(void* manager_ctx, int64_t* shape_ptr) noexcept with gil:
    # Common teardown for exported managed tensors: free the malloc'd
    # shape/strides block and drop the strong reference taken on the
    # exporting StridedMemoryView (stored as manager_ctx).
    if shape_ptr:
        stdlib.free(shape_ptr)
    if manager_ctx:
        cpython.Py_DECREF(<object>manager_ctx)
cdef void _smv_deleter(DLManagedTensor* tensor) noexcept with gil:
    # DLPack deleter for unversioned managed tensors exported by this module.
    # Safe to call with NULL (no-op), which _smv_make_py_capsule relies on.
    if tensor:
        _smv_release_export_resources(tensor.manager_ctx, tensor.dl_tensor.shape)
        tensor.manager_ctx = NULL
        stdlib.free(tensor)
cdef void _smv_versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
    # DLPack deleter for versioned managed tensors exported by this module.
    # Safe to call with NULL (no-op), which _smv_make_py_capsule relies on.
    if tensor:
        _smv_release_export_resources(tensor.manager_ctx, tensor.dl_tensor.shape)
        tensor.manager_ctx = NULL
        stdlib.free(tensor)
cdef inline DLManagedTensorVersioned* _smv_allocate_dlm_tensor_versioned() except? NULL:
    # Allocate a zero-initialized-enough DLManagedTensorVersioned: the two
    # pointer fields the deleters inspect (shape, manager_ctx) are NULLed so
    # a partially-constructed tensor can be passed to the deleter safely.
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    dlm_tensor_ver = <DLManagedTensorVersioned*>stdlib.malloc(sizeof(DLManagedTensorVersioned))
    if dlm_tensor_ver == NULL:
        raise MemoryError()
    dlm_tensor_ver.dl_tensor.shape = NULL
    dlm_tensor_ver.manager_ctx = NULL
    return dlm_tensor_ver
cdef inline DLManagedTensor* _smv_allocate_dlm_tensor() except? NULL:
    # Unversioned counterpart of _smv_allocate_dlm_tensor_versioned; the
    # fields the deleter inspects are NULLed so partial construction is safe.
    cdef DLManagedTensor* dlm_tensor = NULL
    dlm_tensor = <DLManagedTensor*>stdlib.malloc(sizeof(DLManagedTensor))
    if dlm_tensor == NULL:
        raise MemoryError()
    dlm_tensor.dl_tensor.shape = NULL
    dlm_tensor.manager_ctx = NULL
    return dlm_tensor
cdef inline int _smv_dtype_numpy_to_dlpack(object dtype_obj, DLDataType* out_dtype) except -1:
    # Translate a NumPy dtype (or anything numpy.dtype() accepts) into a
    # DLDataType, writing the result into out_dtype. Raises BufferError for
    # anything DLPack cannot represent (structured, non-native-endian, or
    # unsupported kind/width). Returns 0 on success.
    cdef object np_dtype = numpy.dtype(dtype_obj)
    if np_dtype.fields is not None:
        raise BufferError("Structured dtypes are not supported for DLPack export")
    # byteorder "=" (native) and "|" (not applicable, e.g. 1-byte types) are
    # accepted; anything explicitly non-native is rejected.
    if not np_dtype.isnative and np_dtype.byteorder not in ("=", "|"):
        raise BufferError("Non-native-endian dtypes are not supported for DLPack export")

    cdef str kind = np_dtype.kind
    cdef int bits = np_dtype.itemsize * 8
    cdef uint8_t code
    if kind == "b":
        if bits != 8:
            raise BufferError(f"Unsupported bool dtype itemsize: {np_dtype.itemsize}")
        code = <uint8_t>kDLBool
    elif kind == "i":
        if bits not in (8, 16, 32, 64):
            raise BufferError(f"Unsupported signed integer dtype: {np_dtype}")
        code = <uint8_t>kDLInt
    elif kind == "u":
        if bits not in (8, 16, 32, 64):
            raise BufferError(f"Unsupported unsigned integer dtype: {np_dtype}")
        code = <uint8_t>kDLUInt
    elif kind == "f":
        if bits not in (16, 32, 64):
            raise BufferError(f"Unsupported floating dtype: {np_dtype}")
        code = <uint8_t>kDLFloat
    elif kind == "c":
        if bits not in (64, 128):
            raise BufferError(f"Unsupported complex dtype: {np_dtype}")
        code = <uint8_t>kDLComplex
    else:
        raise BufferError(f"Unsupported dtype for DLPack export: {np_dtype}")

    out_dtype.code = code
    out_dtype.bits = <uint8_t>bits
    # Vectorized lanes are never produced from NumPy dtypes.
    out_dtype.lanes = <uint16_t>1
    return 0
cdef inline int _smv_get_dl_device(
    StridedMemoryView view,
    _DLDeviceType* out_device_type,
    int32_t* out_device_id,
) except -1:
    # Determine the DLPack device for a view and write it through the out
    # parameters. Preference order: the imported DLPack tensor's own device;
    # otherwise classify via the Buffer's accessibility flags; otherwise CPU.
    # Returns 0 on success, raises BufferError for inaccessible buffers.
    cdef _DLDeviceType device_type
    cdef int32_t device_id
    cdef object buf
    cdef bint d
    cdef bint h
    if view.dl_tensor != NULL:
        device_type = view.dl_tensor.device.device_type
        if device_type == _kDLCUDA:
            device_id = view.dl_tensor.device.device_id
        else:
            # CPU, CUDAHost, and CUDAManaged use device_id=0 in DLPack.
            device_id = 0
    elif view.is_device_accessible:
        buf = view.get_buffer()
        d = buf.is_device_accessible
        h = buf.is_host_accessible
        if d and (not h):
            device_type = _kDLCUDA
            device_id = buf.device_id
        elif d and h:
            # We do not currently differentiate pinned vs managed here.
            device_type = _kDLCUDAHost
            device_id = 0
        elif (not d) and h:
            device_type = _kDLCPU
            device_id = 0
        else:
            raise BufferError("buffer is neither device-accessible nor host-accessible")
    else:
        device_type = _kDLCPU
        device_id = 0

    out_device_type[0] = device_type
    out_device_id[0] = device_id
    return 0
cdef inline int _smv_setup_dl_tensor_common(
    DLTensor* dl_tensor,
    StridedMemoryView view,
    _StridedLayout layout,
) except -1:
    # Fill the dtype, device, ndim, data pointer, and byte_offset fields of
    # dl_tensor from the view. Shape/strides are set by the callers
    # (_smv_setup_dl_tensor / _smv_setup_dltensor_borrowed), which differ in
    # ownership. Returns 0 on success.
    cdef object dtype_obj = view.get_dtype()
    if dtype_obj is None:
        raise BufferError(
            "Cannot export StridedMemoryView via DLPack without dtype information; "
            "create the view with dtype specified."
        )
    _smv_dtype_numpy_to_dlpack(dtype_obj, &dl_tensor.dtype)
    _smv_get_dl_device(view, &dl_tensor.device.device_type, &dl_tensor.device.device_id)

    cdef int ndim = layout.base.ndim
    dl_tensor.ndim = ndim
    if layout.get_volume() == 0:
        # Empty tensors export a NULL data pointer.
        dl_tensor.data = NULL
    else:
        dl_tensor.data = <void*><intptr_t>view.ptr
    dl_tensor.byte_offset = 0
    return 0
cdef inline int _smv_setup_dl_tensor(DLTensor* dl_tensor, StridedMemoryView view) except -1:
    # Owning variant: shape and strides are copied into a single malloc'd
    # block of 2*ndim int64s (shape first, strides second) so the exported
    # tensor outlives the view's layout. The block is freed by the managed
    # tensor's deleter via _smv_release_export_resources. Returns 0 on success.
    cdef _StridedLayout layout = view.get_layout()
    _smv_setup_dl_tensor_common(dl_tensor, view, layout)

    cdef int i
    cdef int64_t* shape_strides = NULL
    cdef int64_t* strides_src = NULL
    cdef int ndim = dl_tensor.ndim
    if ndim == 0:
        dl_tensor.shape = NULL
        dl_tensor.strides = NULL
    else:
        # DLPack v1.2+ requires non-NULL strides for ndim != 0.
        shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * 2 * ndim)
        if shape_strides == NULL:
            raise MemoryError()
        try:
            strides_src = get_strides_ptr(layout.base)
            for i in range(ndim):
                shape_strides[i] = layout.base.shape[i]
                shape_strides[i + ndim] = strides_src[i]
        except Exception:
            # Don't leak the block if stride extraction fails.
            stdlib.free(shape_strides)
            raise
        dl_tensor.shape = shape_strides
        dl_tensor.strides = shape_strides + ndim
    return 0
cdef inline int _smv_setup_dltensor_borrowed(DLTensor* dl_tensor, StridedMemoryView view) except -1:
    # Non-owning variant of _smv_setup_dl_tensor: shape/strides point directly
    # into the view's layout, so the resulting DLTensor is only valid while
    # the view (and its layout) stay alive. Returns 0 on success.
    cdef _StridedLayout layout = view.get_layout()
    _smv_setup_dl_tensor_common(dl_tensor, view, layout)

    if dl_tensor.ndim == 0:
        dl_tensor.shape = NULL
        dl_tensor.strides = NULL
    else:
        dl_tensor.shape = layout.base.shape
        # For temporary/non-owning exchange we provide explicit strides.
        dl_tensor.strides = get_strides_ptr(layout.base)
    return 0
cdef inline int _smv_fill_managed_tensor_versioned(
    DLManagedTensorVersioned* dlm_tensor_ver,
    StridedMemoryView view,
) except -1:
    # Populate a versioned managed tensor for export. Takes a strong
    # reference to the view (released by _smv_versioned_deleter); manager_ctx
    # and deleter are set before the fallible setup so a failure can still be
    # cleaned up by calling the deleter. Returns 0 on success.
    cpython.Py_INCREF(view)
    dlm_tensor_ver.manager_ctx = <void*>view
    dlm_tensor_ver.deleter = _smv_versioned_deleter
    dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
    dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
    dlm_tensor_ver.flags = DLPACK_FLAG_BITMASK_READ_ONLY if view.readonly else 0
    _smv_setup_dl_tensor(&dlm_tensor_ver.dl_tensor, view)
    return 0
cdef inline int _smv_fill_managed_tensor(
    DLManagedTensor* dlm_tensor,
    StridedMemoryView view,
) except -1:
    # Unversioned counterpart of _smv_fill_managed_tensor_versioned (legacy
    # DLPack has no version/flags fields, so readonly cannot be conveyed).
    cpython.Py_INCREF(view)
    dlm_tensor.manager_ctx = <void*>view
    dlm_tensor.deleter = _smv_deleter
    _smv_setup_dl_tensor(&dlm_tensor.dl_tensor, view)
    return 0
cdef object _smv_make_py_capsule(StridedMemoryView view, bint versioned):
    # Allocate, fill, and wrap a (possibly versioned) managed tensor in a
    # PyCapsule bearing the DLPack "unused" name. On any failure before the
    # capsule exists, the partially built tensor is freed via the deleters
    # (both are NULL-safe, so calling both unconditionally is fine); once the
    # capsule exists, its destructor owns cleanup.
    cdef DLManagedTensor* dlm_tensor = NULL
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    cdef object capsule = None
    cdef void* tensor_ptr = NULL
    cdef const char* capsule_name
    try:
        if versioned:
            dlm_tensor_ver = _smv_allocate_dlm_tensor_versioned()
            _smv_fill_managed_tensor_versioned(dlm_tensor_ver, view)
            tensor_ptr = <void*>dlm_tensor_ver
            capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
        else:
            dlm_tensor = _smv_allocate_dlm_tensor()
            _smv_fill_managed_tensor(dlm_tensor, view)
            tensor_ptr = <void*>dlm_tensor
            capsule_name = DLPACK_TENSOR_UNUSED_NAME
        capsule = cpython.PyCapsule_New(tensor_ptr, capsule_name, _smv_pycapsule_deleter)
    except Exception:
        if capsule is None:
            _smv_deleter(dlm_tensor)
            _smv_versioned_deleter(dlm_tensor_ver)
        raise
    return capsule
cdef inline StridedMemoryView _smv_from_dlpack_capsule(object capsule, object exporting_obj):
    # Import a DLPack capsule (versioned or legacy) into a new
    # StridedMemoryView. Renames the capsule to the "used" name, which per the
    # DLPack protocol transfers deleter responsibility to the view (see
    # __dealloc__). The read-only flag only exists in the versioned struct;
    # legacy tensors are assumed writable.
    cdef void* data = NULL
    cdef DLTensor* dl_tensor = NULL
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    cdef DLManagedTensor* dlm_tensor = NULL
    cdef bint is_readonly = False
    cdef const char* used_name = NULL
    if cpython.PyCapsule_IsValid(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        dlm_tensor_ver = <DLManagedTensorVersioned*>data
        dl_tensor = &dlm_tensor_ver.dl_tensor
        is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0)
        used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
    elif cpython.PyCapsule_IsValid(capsule, DLPACK_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(capsule, DLPACK_TENSOR_UNUSED_NAME)
        dlm_tensor = <DLManagedTensor*>data
        dl_tensor = &dlm_tensor.dl_tensor
        is_readonly = False
        used_name = DLPACK_TENSOR_USED_NAME
    else:
        raise BufferError("Invalid DLPack capsule")

    cpython.PyCapsule_SetName(capsule, used_name)

    cdef StridedMemoryView view = StridedMemoryView.__new__(StridedMemoryView)
    view.dl_tensor = dl_tensor
    view.metadata = capsule
    view.ptr = <intptr_t>(dl_tensor.data) + <intptr_t>(dl_tensor.byte_offset)
    view.readonly = is_readonly
    view.exporting_obj = exporting_obj
    if dl_tensor.device.device_type == _kDLCPU:
        view.device_id = -1
        view.is_device_accessible = False
    elif dl_tensor.device.device_type in (_kDLCUDA, _kDLCUDAHost, _kDLCUDAManaged):
        view.device_id = dl_tensor.device.device_id
        view.is_device_accessible = True
    else:
        raise BufferError("device not supported")
    return view
cdef int _smv_managed_tensor_allocator(
    DLTensor* prototype,
    DLManagedTensorVersioned** out,
    void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message) noexcept,
) noexcept with gil:
    # Exchange-API allocator slot: intentionally unimplemented. Reports the
    # error both through the caller-supplied SetError callback and as a
    # pending Python exception, then returns -1.
    if out != NULL:
        out[0] = NULL
    if SetError != NULL:
        SetError(error_ctx, b"NotImplementedError", b"managed_tensor_allocator is not supported by StridedMemoryView")
    cpython.PyErr_SetString(NotImplementedError, b"managed_tensor_allocator is not supported by StridedMemoryView")
    return -1
cdef int _smv_managed_tensor_from_py_object_no_sync(
    void* py_object,
    DLManagedTensorVersioned** out,
) noexcept with gil:
    # Exchange-API slot: export a StridedMemoryView (passed as a borrowed
    # PyObject*) into a freshly allocated versioned managed tensor. On
    # failure, any partially built tensor is freed (the deleter is NULL-safe)
    # and -1 is returned with a Python exception set.
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    if out == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out cannot be NULL")
        return -1
    out[0] = NULL
    cdef object obj = <object>py_object
    if not isinstance(obj, StridedMemoryView):
        cpython.PyErr_SetString(TypeError, b"py_object must be a StridedMemoryView")
        return -1
    try:
        dlm_tensor_ver = _smv_allocate_dlm_tensor_versioned()
        _smv_fill_managed_tensor_versioned(dlm_tensor_ver, <StridedMemoryView>obj)
    except Exception:
        _smv_versioned_deleter(dlm_tensor_ver)
        return -1
    out[0] = dlm_tensor_ver
    return 0
cdef int _smv_managed_tensor_to_py_object_no_sync(
    DLManagedTensorVersioned* tensor,
    void** out_py_object,
) noexcept with gil:
    # Exchange-API slot: wrap an incoming versioned managed tensor in a
    # capsule and import it as a StridedMemoryView, returning a new strong
    # reference through out_py_object. The capsule doubles as the view's
    # exporting_obj so the tensor's lifetime is tied to the view.
    cdef object capsule
    cdef object py_view
    if out_py_object == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out_py_object cannot be NULL")
        return -1
    out_py_object[0] = NULL
    if tensor == NULL:
        cpython.PyErr_SetString(RuntimeError, b"tensor cannot be NULL")
        return -1
    try:
        capsule = cpython.PyCapsule_New(
            <void*>tensor,
            DLPACK_VERSIONED_TENSOR_UNUSED_NAME,
            _smv_pycapsule_deleter,
        )
        py_view = _smv_from_dlpack_capsule(capsule, capsule)
        # Hand a new strong reference to the C caller.
        cpython.Py_INCREF(py_view)
        out_py_object[0] = <void*>py_view
    except Exception:
        return -1
    return 0
cdef int _smv_dltensor_from_py_object_no_sync(
    void* py_object,
    DLTensor* out,
) noexcept with gil:
    # DLPack exchange-API slot: fill *out* with a borrowed DLTensor
    # description of the given StridedMemoryView. The view keeps ownership
    # of the underlying metadata ("borrowed" setup); no stream
    # synchronization is performed.
    #
    # Returns 0 on success, -1 with a Python exception set on failure.
    if out == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out cannot be NULL")
        return -1
    cdef object obj = <object>py_object
    if not isinstance(obj, StridedMemoryView):
        cpython.PyErr_SetString(TypeError, b"py_object must be a StridedMemoryView")
        return -1
    try:
        _smv_setup_dltensor_borrowed(out, <StridedMemoryView>obj)
    except Exception:
        return -1
    return 0
cdef int _smv_current_work_stream(
    _DLDeviceType device_type,
    int32_t device_id,
    void** out_current_stream,
) noexcept with gil:
    # DLPack exchange-API slot: report the consumer-side "current" work
    # stream for the given device. Returns 0 on success, -1 when the output
    # pointer is missing.
    if out_current_stream == NULL:
        cpython.PyErr_SetString(RuntimeError, b"out_current_stream cannot be NULL")
        return -1
    # cuda.core has no global/current stream state today, so always report
    # NULL (i.e. "no current stream").
    out_current_stream[0] = NULL
    return 0
cdef void _init_smv_dlpack_exchange_api():
    # One-time initialization of the module-global DLPackExchangeAPI struct
    # (whose address is exported via _SMV_DLPACK_EXCHANGE_API_CAPSULE).
    # Idempotent: guarded by the _SMV_DLPACK_EXCHANGE_API_INITED flag.
    global _SMV_DLPACK_EXCHANGE_API_INITED
    if _SMV_DLPACK_EXCHANGE_API_INITED:
        return
    _SMV_DLPACK_EXCHANGE_API.header.version.major = DLPACK_MAJOR_VERSION
    _SMV_DLPACK_EXCHANGE_API.header.version.minor = DLPACK_MINOR_VERSION
    _SMV_DLPACK_EXCHANGE_API.header.prev_api = NULL
    # Wire each API slot to the module-level callback implementations above.
    _SMV_DLPACK_EXCHANGE_API.managed_tensor_allocator = _smv_managed_tensor_allocator
    _SMV_DLPACK_EXCHANGE_API.managed_tensor_from_py_object_no_sync = _smv_managed_tensor_from_py_object_no_sync
    _SMV_DLPACK_EXCHANGE_API.managed_tensor_to_py_object_no_sync = _smv_managed_tensor_to_py_object_no_sync
    _SMV_DLPACK_EXCHANGE_API.dltensor_from_py_object_no_sync = _smv_dltensor_from_py_object_no_sync
    _SMV_DLPACK_EXCHANGE_API.current_work_stream = _smv_current_work_stream
    _SMV_DLPACK_EXCHANGE_API_INITED = True
# Populate the process-wide exchange-API struct once at module import.
_init_smv_dlpack_exchange_api()
# cdef classes are immutable types in Cython 3, so inject these attributes
# directly into the type dict. Both attribute spellings are published,
# pointing at the same capsule.
(<dict>(<PyTypeObject*>StridedMemoryView).tp_dict)["__dlpack_c_exchange_api__"] = _SMV_DLPACK_EXCHANGE_API_CAPSULE
(<dict>(<PyTypeObject*>StridedMemoryView).tp_dict)["__c_dlpack_exchange_api__"] = _SMV_DLPACK_EXCHANGE_API_CAPSULE
# Notify CPython that the type dict changed so cached attribute lookups
# (method/attribute caches) are invalidated.
PyType_Modified(<PyTypeObject*>StridedMemoryView)
cdef str get_simple_repr(obj):
    """Return a compact qualified name for *obj* (or for its class).

    Accepts either a type or an instance. Builtin / module-less classes
    render as the bare class name; everything else as ``module.ClassName``.
    """
    # TODO: better handling in np.dtype objects
    cdef object cls = obj if isinstance(obj, type) else obj.__class__
    cdef object module = cls.__module__
    if module is None or module == "builtins":
        return cls.__name__
    return f"{module}.{cls.__name__}"
cdef bint check_has_dlpack(obj) except*:
    """Probe which data-exchange protocol *obj* supports.

    Returns True when *obj* implements DLPack (both ``__dlpack__`` and
    ``__dlpack_device__``), False when it only exposes the CUDA Array
    Interface. Raises RuntimeError when neither protocol is available.
    """
    if hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"):
        return True
    if hasattr(obj, "__cuda_array_interface__"):
        return False
    raise RuntimeError(
        "the input object does not support any data exchange protocol")
cdef class _StridedMemoryViewProxy:
    # Lightweight proxy created by the args_viewable_as_strided_memory
    # decorator. It records which exchange protocol the wrapped object
    # supports and defers constructing the actual StridedMemoryView until
    # the consumer supplies a stream via view().
    cdef readonly:
        # the original exporting array/tensor object
        object obj
        # True -> DLPack path; False -> CUDA Array Interface path
        bint has_dlpack

    def __init__(self, obj):
        self.obj = obj
        # Raises RuntimeError immediately if obj supports neither protocol.
        self.has_dlpack = check_has_dlpack(obj)

    cpdef StridedMemoryView view(self, stream_ptr=None):
        # Materialize the view using whichever protocol was detected.
        if self.has_dlpack:
            return StridedMemoryView.from_dlpack(self.obj, stream_ptr)
        else:
            return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr)
cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
    # Build a StridedMemoryView from any DLPack exporter.
    #
    # stream_ptr semantics: None is rejected (ambiguous); -1 means "no
    # synchronization requested"; any other value is the consumer stream
    # handle, passed through to the producer's __dlpack__.
    # When *view* is given it is filled in-place and returned; otherwise a
    # fresh StridedMemoryView is created.
    cdef int dldevice, device_id
    cdef bint is_device_accessible, is_readonly
    is_device_accessible = False
    dldevice, device_id = obj.__dlpack_device__()
    if dldevice == _kDLCPU:
        assert device_id == 0
        device_id = -1
        if stream_ptr is None:
            raise BufferError("stream=None is ambiguous with view()")
        elif stream_ptr == -1:
            # CPU memory needs no stream ordering; drop the sentinel.
            stream_ptr = None
    elif dldevice == _kDLCUDA:
        assert device_id >= 0
        is_device_accessible = True
        # no need to check other stream values, it's a pass-through
        if stream_ptr is None:
            raise BufferError("stream=None is ambiguous with view()")
    elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
        is_device_accessible = True
        # just do a pass-through without any checks, as pinned/managed memory can be
        # accessed on both host and device
    else:
        raise BufferError("device not supported")

    cdef object capsule
    try:
        # Prefer versioned DLPack (1.x); fall back to the legacy call for
        # producers whose __dlpack__ does not accept max_version.
        capsule = obj.__dlpack__(
            stream=int(stream_ptr) if stream_ptr else None,
            max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
    except TypeError:
        capsule = obj.__dlpack__(
            stream=int(stream_ptr) if stream_ptr else None)

    cdef void* data = NULL
    cdef DLTensor* dl_tensor
    cdef DLManagedTensorVersioned* dlm_tensor_ver
    cdef DLManagedTensor* dlm_tensor
    cdef const char *used_name
    if cpython.PyCapsule_IsValid(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        # Versioned capsule: the read-only flag lives in the tensor flags.
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        dlm_tensor_ver = <DLManagedTensorVersioned*>data
        dl_tensor = &dlm_tensor_ver.dl_tensor
        is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0)
        used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
    elif cpython.PyCapsule_IsValid(
            capsule, DLPACK_TENSOR_UNUSED_NAME):
        # Legacy (unversioned) capsule: no read-only flag exists.
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_TENSOR_UNUSED_NAME)
        dlm_tensor = <DLManagedTensor*>data
        dl_tensor = &dlm_tensor.dl_tensor
        is_readonly = False
        used_name = DLPACK_TENSOR_USED_NAME
    else:
        assert False

    # Rename the capsule to its "used" name per the DLPack protocol,
    # signaling that ownership of the tensor has been consumed.
    cpython.PyCapsule_SetName(capsule, used_name)

    cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
    buf.dl_tensor = dl_tensor
    # Holding the capsule keeps the managed tensor (and its memory) alive.
    buf.metadata = capsule
    buf.ptr = <intptr_t>(dl_tensor.data)
    buf.device_id = device_id
    buf.is_device_accessible = is_device_accessible
    buf.readonly = is_readonly
    buf.exporting_obj = obj

    return buf
@functools.lru_cache
def _typestr2dtype(str typestr):
    """Cached conversion of an array-interface ``typestr`` (e.g. ``"<f4"``) to a numpy dtype."""
    return numpy.dtype(typestr)
@functools.lru_cache
def _typestr2itemsize(str typestr):
    """Cached itemsize (in bytes) of an array-interface ``typestr``."""
    return _typestr2dtype(typestr).itemsize
cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
    """Map a DLPack ``DLDataType`` to the corresponding numpy dtype object.

    Only scalar dtypes (``lanes == 1``) are supported. bfloat16 requires
    the optional ``ml_dtypes`` package (imported at module scope).

    Raises
    ------
    NotImplementedError
        For vector dtypes, or for bfloat16 when ``ml_dtypes`` is missing.
    TypeError
        For unsupported bit widths or dtype codes.
    """
    cdef int bits = dtype.bits
    if dtype.lanes != 1:
        # TODO: return a NumPy structured dtype?
        raise NotImplementedError(
            f'vector dtypes (lanes={dtype.lanes}) is not supported')
    if dtype.code == kDLUInt:
        if bits == 8:
            np_dtype = numpy.uint8
        elif bits == 16:
            np_dtype = numpy.uint16
        elif bits == 32:
            np_dtype = numpy.uint32
        elif bits == 64:
            np_dtype = numpy.uint64
        else:
            raise TypeError(f'uint{bits} is not supported.')
    elif dtype.code == kDLInt:
        if bits == 8:
            np_dtype = numpy.int8
        elif bits == 16:
            np_dtype = numpy.int16
        elif bits == 32:
            np_dtype = numpy.int32
        elif bits == 64:
            np_dtype = numpy.int64
        else:
            raise TypeError(f'int{bits} is not supported.')
    elif dtype.code == kDLFloat:
        if bits == 16:
            np_dtype = numpy.float16
        elif bits == 32:
            np_dtype = numpy.float32
        elif bits == 64:
            np_dtype = numpy.float64
        else:
            raise TypeError(f'float{bits} is not supported.')
    elif dtype.code == kDLComplex:
        # TODO(leofang): support complex32
        if bits == 64:
            np_dtype = numpy.complex64
        elif bits == 128:
            np_dtype = numpy.complex128
        else:
            raise TypeError(f'complex{bits} is not supported.')
    elif dtype.code == kDLBool:
        if bits == 8:
            np_dtype = numpy.bool_
        else:
            raise TypeError(f'{bits}-bit bool is not supported')
    elif dtype.code == kDLBfloat:
        if bfloat16 is not None:
            # ml_dtypes registers "bfloat16" with numpy on import.
            np_dtype = numpy.dtype("bfloat16")
        else:
            # Fixed: the original implicit string concatenation was missing
            # a space ("...`ml_dtypes`to be installed.").
            raise NotImplementedError(
                'Support for bfloat16 within cuda-core requires `ml_dtypes` '
                'to be installed.'
            )
    else:
        raise TypeError(f'Unsupported dtype. dtype code: {dtype.code}')

    # We want the dtype object not just the type object
    return numpy.dtype(np_dtype)
cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
    # Build a StridedMemoryView from a CUDA Array Interface (v3) exporter.
    #
    # stream_ptr semantics: None is rejected (ambiguous); -1 disables
    # stream-order establishment; otherwise it is the consumer stream that
    # is made to wait on the producer stream advertised in the CAI metadata.
    # When *view* is given it is filled in-place and returned; otherwise a
    # fresh StridedMemoryView is created.
    cdef dict cai_data = obj.__cuda_array_interface__
    if cai_data["version"] < 3:
        raise BufferError("only CUDA Array Interface v3 or above is supported")
    if cai_data.get("mask") is not None:
        raise BufferError("mask is not supported")
    if stream_ptr is None:
        raise BufferError("stream=None is ambiguous with view()")

    cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
    buf.exporting_obj = obj
    buf.metadata = cai_data
    buf.dl_tensor = NULL
    # Validate shape/strides/typestr eagerly so constructor paths fail fast.
    buf.get_layout()
    buf.ptr, buf.readonly = cai_data["data"]
    buf.is_device_accessible = True
    if buf.ptr != 0:
        # Ask the driver which device ordinal owns this pointer.
        buf.device_id = handle_return(
            driver.cuPointerGetAttribute(
                driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                buf.ptr))
    else:
        # Null data pointer (e.g. zero-sized array): fall back to the
        # current context's device.
        buf.device_id = handle_return(driver.cuCtxGetDevice())

    cdef intptr_t producer_s, consumer_s
    cdef EventHandle h_event
    stream_ptr = int(stream_ptr)
    if stream_ptr != -1:
        stream = cai_data.get("stream")
        if stream is not None:
            producer_s = <intptr_t>(stream)
            consumer_s = <intptr_t>(stream_ptr)
            assert producer_s > 0
            # establish stream order
            if producer_s != consumer_s:
                # Record an event on the producer stream and make the
                # consumer stream wait on it, all without the GIL.
                with nogil:
                    h_event = create_event_handle_noctx(cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING)
                    HANDLE_RETURN(cydriver.cuEventRecord(
                        as_cu(h_event), <cydriver.CUstream>producer_s))
                    HANDLE_RETURN(cydriver.cuStreamWaitEvent(
                        <cydriver.CUstream>consumer_s, as_cu(h_event), 0))

    return buf
cpdef StridedMemoryView view_as_array_interface(obj, view=None):
    """Build a StridedMemoryView from a NumPy ``__array_interface__`` exporter.

    Fills *view* in-place when provided, otherwise creates a fresh
    StridedMemoryView. Raises BufferError for pre-v3 interfaces or for
    masked arrays.
    """
    cdef dict meta = obj.__array_interface__
    if meta["version"] < 3:
        raise BufferError("only NumPy Array Interface v3 or above is supported")
    if meta.get("mask") is not None:
        raise BufferError("mask is not supported")

    cdef StridedMemoryView smv
    smv = StridedMemoryView() if view is None else view
    smv.exporting_obj = obj
    smv.metadata = meta
    smv.dl_tensor = NULL
    # Validate shape/strides/typestr eagerly so constructor paths fail fast.
    smv.get_layout()
    smv.ptr, smv.readonly = meta["data"]
    smv.is_device_accessible = False
    smv.device_id = handle_return(driver.cuCtxGetDevice())
    return smv
def args_viewable_as_strided_memory(tuple arg_indices):
    """
    Decorator to create proxy objects to :obj:`StridedMemoryView` for the
    specified positional arguments.

    This allows array/tensor attributes to be accessed inside the function
    implementation, while keeping the function body array-library-agnostic (if
    desired).

    Inside the decorated function, the specified arguments become instances
    of an (undocumented) proxy type, regardless of its original source. A
    :obj:`StridedMemoryView` instance can be obtained by passing the (consumer)
    stream pointer (as a Python `int`) to the proxies's ``view()`` method. For
    example:

    .. code-block:: python

        @args_viewable_as_strided_memory((1,))
        def my_func(arg0, arg1, arg2, stream: Stream):
            # arg1 can be any object supporting DLPack or CUDA Array Interface
            view = arg1.view(stream.handle)
            assert isinstance(view, StridedMemoryView)
            ...

    Parameters
    ----------
    arg_indices : tuple
        The indices of the target positional arguments.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            cdef int pos
            # Wrap each targeted positional argument in a proxy.
            proxied = list(args)
            for pos in arg_indices:
                proxied[pos] = _StridedMemoryViewProxy(proxied[pos])
            return func(*proxied, **kwargs)
        return wrapper
    return decorator
cdef inline _StridedLayout layout_from_dlpack(DLTensor* dl_tensor):
    """Construct a _StridedLayout describing *dl_tensor*'s shape/strides.

    Raises ValueError when the element width is not a whole number of bytes.
    """
    cdef int total_bits = dl_tensor.dtype.bits * dl_tensor.dtype.lanes
    if total_bits % 8 != 0:
        raise ValueError("dl_tensor.dtype.bits must be a multiple of 8")
    cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout)
    layout.init_from_ptr(dl_tensor.ndim, dl_tensor.shape, dl_tensor.strides, total_bits // 8)
    return layout
cdef _StridedLayout layout_from_cai(object metadata):
    """Construct a _StridedLayout from CUDA/NumPy array-interface metadata.

    Reads ``shape``, optional ``strides``, and ``typestr`` (for itemsize)
    from the metadata dict.
    """
    cdef int width = _typestr2itemsize(metadata["typestr"])
    cdef _StridedLayout result = _StridedLayout.__new__(_StridedLayout)
    result.init_from_tuple(metadata["shape"], metadata.get("strides"), width, True)
    return result
cdef inline intptr_t get_data_ptr(object buffer, _StridedLayout layout) except? 0:
    # Data pointer = buffer base address + the layout's slice offset.
    cdef intptr_t base = <intptr_t>(int(buffer.handle))
    return base + layout.get_slice_offset_in_bytes()
cdef inline int view_buffer_strided(
    StridedMemoryView view,
    object buffer,
    _StridedLayout layout,
    object dtype,
    bint is_readonly,
) except -1:
    # Populate *view* so it describes *buffer* through *layout*.
    #
    # Validates that an explicit *dtype* (if any) matches the layout's
    # itemsize, and that an allocated buffer is large enough to hold the
    # layout. Returns 0 on success; raises ValueError on validation failure.
    if dtype is not None:
        dtype = numpy.dtype(dtype)
        if dtype.itemsize != layout.itemsize:
            raise ValueError(
                f"The dtype's itemsize ({dtype.itemsize}) does not match the layout's "
                f"itemsize ({layout.itemsize})."
            )
    # Check the layout's offset range [min_offset, max_offset] fits
    # within the [0, buffer.size - 1] range.
    # The required_size_in_bytes fails if min_offset < 0.
    # NB. For external memory, both positive and negative offsets can be valid,
    # but for a proper check we'd need to know both size and data offset,
    # while neither is reported by the packages.
    cdef bint is_allocated = buffer.memory_resource is not None
    if is_allocated and buffer.size < layout.get_required_size_in_bytes():
        raise ValueError(
            f"Buffer size is too small for the layout. "
            f"Expected at least {layout.get_required_size_in_bytes()} bytes, "
            f"got {buffer.size} bytes."
        )
    # set the public attributes
    view.ptr = get_data_ptr(buffer, layout)
    view.device_id = buffer.device_id
    view.is_device_accessible = buffer.is_device_accessible
    view.readonly = is_readonly
    # _buffer keeps the Buffer alive; exporting_obj mirrors the DLPack/CAI paths.
    view.exporting_obj = view._buffer = buffer
    # no dlpack/cai metadata
    view.dl_tensor = NULL
    view.metadata = None
    # we get the layout from the caller
    view._layout = layout
    view._dtype = dtype
    return 0