Coverage for cuda/core/_tensor_map.pyx: 29.60%
527 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from libc.stdint cimport intptr_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t
6from libc.stddef cimport size_t
7from cuda.bindings cimport cydriver
8from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
9from cuda.core._dlpack cimport kDLInt, kDLUInt, kDLFloat, kDLBfloat, _kDLCUDA
11import enum
12from dataclasses import dataclass
13from typing import TYPE_CHECKING
15import numpy
17from cuda.core._memoryview import StridedMemoryView
18from cuda.core._utils.cuda_utils import check_or_create_options
20if TYPE_CHECKING:
21 from cuda.core._device import Device
cdef extern from "_cpp/tensor_map_cccl.h":
    # C shim used by TensorMapDescriptor._from_tiled to encode a tiled
    # CUtensorMap from DLPack-style tensor metadata.
    #
    # As used in this file: a return value of 0 means ``out_tensor_map`` was
    # filled in; any other value means a NUL-terminated diagnostic was written
    # into ``err`` (at most ``err_cap`` bytes).  ``strides`` and
    # ``elem_strides`` may be NULL when the caller has nothing to supply.
    int cuda_core_cccl_make_tma_descriptor_tiled(
        # Output descriptor storage and the tensor's base address.
        void* out_tensor_map, void* data,
        # DLPack device tuple of the tensor.
        int device_type, int device_id,
        # Tensor geometry; shape/strides are given in the caller's dimension order.
        int ndim, const int64_t* shape, const int64_t* strides,
        # DLPack dtype triple (code, bits, lanes).
        uint8_t dtype_code, uint8_t dtype_bits, uint16_t dtype_lanes,
        # Tile geometry.
        const int* box_sizes, const int* elem_strides,
        # Driver enum values passed as plain ints.
        int interleave_layout, int swizzle, int l2_fetch_size, int oob_fill,
        # Caller-owned error message buffer.
        char* err, size_t err_cap) nogil
45try:
46 from ml_dtypes import bfloat16 as ml_bfloat16
47except ImportError:
48 ml_bfloat16 = None
class TensorMapDataType(enum.IntEnum):
    """Element data types understood by tensor map descriptors.

    Each member carries the value of the matching ``CUtensorMapDataType``
    driver enum constant.
    """

    # Unsigned / signed integers
    UINT8 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8
    UINT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16
    UINT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32
    INT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32
    UINT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64
    INT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64
    # Floating point
    FLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16
    FLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32
    FLOAT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64
    BFLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16
    FLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ
    TFLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32
    TFLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
class TensorMapInterleave(enum.IntEnum):
    """Interleave layouts selectable for a tensor map descriptor.

    Each member carries the value of the matching ``CUtensorMapInterleave``
    driver enum constant.
    """

    NONE = cydriver.CU_TENSOR_MAP_INTERLEAVE_NONE
    INTERLEAVE_16B = cydriver.CU_TENSOR_MAP_INTERLEAVE_16B
    INTERLEAVE_32B = cydriver.CU_TENSOR_MAP_INTERLEAVE_32B
class TensorMapSwizzle(enum.IntEnum):
    """Swizzle modes selectable for a tensor map descriptor.

    Each member carries the value of the matching ``CUtensorMapSwizzle``
    driver enum constant.
    """

    NONE = cydriver.CU_TENSOR_MAP_SWIZZLE_NONE
    SWIZZLE_32B = cydriver.CU_TENSOR_MAP_SWIZZLE_32B
    SWIZZLE_64B = cydriver.CU_TENSOR_MAP_SWIZZLE_64B
    SWIZZLE_128B = cydriver.CU_TENSOR_MAP_SWIZZLE_128B
class TensorMapL2Promotion(enum.IntEnum):
    """L2 promotion modes selectable for a tensor map descriptor.

    Each member carries the value of the matching ``CUtensorMapL2promotion``
    driver enum constant.
    """

    NONE = cydriver.CU_TENSOR_MAP_L2_PROMOTION_NONE
    L2_64B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_64B
    L2_128B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_128B
    L2_256B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_256B
class TensorMapOOBFill(enum.IntEnum):
    """Out-of-bounds fill modes selectable for a tensor map descriptor.

    Each member carries the value of the matching ``CUtensorMapFloatOOBfill``
    driver enum constant.
    """

    NONE = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    NAN_REQUEST_ZERO_FMA = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
# The enum is defined under both build configurations so that importing the
# name never fails; only the member values differ.
IF CUDA_CORE_BUILD_MAJOR >= 13:
    class TensorMapIm2ColWideMode(enum.IntEnum):
        """Im2col wide modes selectable for a tensor map descriptor.

        Each member carries the value of the matching
        ``CUtensorMapIm2ColWideMode`` driver enum constant.
        Supported on compute capability 10.0+.
        """

        W = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W
        W128 = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W128
ELSE:
    class TensorMapIm2ColWideMode(enum.IntEnum):
        """Im2col wide modes selectable for a tensor map descriptor.

        The enum exists on every build for API stability, but the
        :meth:`TensorMapDescriptor._from_im2col_wide` factory needs a
        CUDA 13+ build and raises otherwise.
        """

        # Literal placeholders: the driver constants are not referenced on
        # builds older than CUDA 13.
        W = 0
        W128 = 1
# Plain-int copies of the CUtensorMapDataType driver constants.  The rest of
# this module uses these as dictionary keys and in equality checks.
_TMA_DT_UINT8 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8)
_TMA_DT_UINT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16)
_TMA_DT_UINT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32)
_TMA_DT_INT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32)
_TMA_DT_UINT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64)
_TMA_DT_INT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64)
_TMA_DT_FLOAT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16)
_TMA_DT_FLOAT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32)
_TMA_DT_FLOAT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64)
_TMA_DT_BFLOAT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16)
_TMA_DT_FLOAT32_FTZ = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ)
_TMA_DT_TFLOAT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32)
_TMA_DT_TFLOAT32_FTZ = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ)
def _normalize_tensor_map_data_type(data_type):
    """Canonicalise a user-supplied ``data_type`` value.

    ``None`` and :class:`TensorMapDataType` members are returned untouched.
    Anything else is handed to ``numpy.dtype`` and the resulting dtype object
    is returned.

    Raises
    ------
    TypeError
        If ``numpy.dtype`` rejects the value with a ``TypeError``.
    """
    if data_type is None:
        return data_type
    if isinstance(data_type, TensorMapDataType):
        return data_type

    try:
        resolved = numpy.dtype(data_type)
    except TypeError as e:
        raise TypeError(
            "data_type must be a TensorMapDataType or a numpy/ml_dtypes dtype, "
            f"got {type(data_type)}") from e
    return resolved
def _normalize_tensor_map_sequence(name, values):
    """Convert ``values`` into a tuple of integers, validating each element.

    Parameters
    ----------
    name : str
        Argument name used in error messages (for example ``"box_dim"``).
    values : iterable
        The user-supplied sequence.

    Returns
    -------
    tuple
        The elements of ``values``.  Elements that are already ``int``
        instances are kept as-is; other integer-like objects (anything
        implementing ``__index__``, such as NumPy integer scalars) are
        converted with :func:`operator.index`.

    Raises
    ------
    TypeError
        If ``values`` is not iterable, or an element is not integer-like
        (floats and strings are rejected).
    """
    import operator

    try:
        items = tuple(values)
    except TypeError as e:
        raise TypeError(f"{name} must be a tuple of ints, got {type(values)}") from e

    normalized = []
    for i, value in enumerate(items):
        if isinstance(value, int):
            # Fast path; keeps the original object exactly as before.
            normalized.append(value)
            continue
        try:
            # Accept integer-like scalars without silently truncating floats:
            # operator.index() refuses anything lacking __index__.
            normalized.append(operator.index(value))
        except TypeError:
            raise TypeError(f"{name}[{i}] must be an int, got {type(value)}") from None
    return tuple(normalized)
def _require_tensor_map_enum(name, value, enum_type):
    """Return ``value`` if it is an instance of ``enum_type``.

    Raises
    ------
    TypeError
        If ``value`` is not an ``enum_type`` instance; ``name`` identifies
        the offending argument in the message.
    """
    if isinstance(value, enum_type):
        return value
    raise TypeError(f"{name} must be a {enum_type.__name__}, got {type(value)}")
@dataclass
class TensorMapDescriptorOptions:
    """Bundled options for :meth:`cuda.core.StridedMemoryView.as_tensor_map`.

    All fields are validated and normalised in ``__post_init__``.

    Attributes
    ----------
    box_dim : tuple[int, ...]
        Tile size for each tensor dimension, expressed in elements.
    element_strides : tuple[int, ...], optional
        Per-dimension element traversal strides.
    data_type : object, optional
        Explicit dtype override. Prefer NumPy or ``ml_dtypes`` dtype objects;
        :class:`TensorMapDataType` remains accepted for compatibility.
    interleave : TensorMapInterleave, optional
        Interleave layout. Default ``NONE``.
    swizzle : TensorMapSwizzle, optional
        Swizzle mode. Default ``NONE``.
    l2_promotion : TensorMapL2Promotion, optional
        L2 promotion mode. Default ``NONE``.
    oob_fill : TensorMapOOBFill, optional
        Out-of-bounds fill mode. Default ``NONE``.
    """

    box_dim: tuple[int, ...]
    element_strides: tuple[int, ...] | None = None
    data_type: object = None
    interleave: TensorMapInterleave = TensorMapInterleave.NONE
    swizzle: TensorMapSwizzle = TensorMapSwizzle.NONE
    l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE
    oob_fill: TensorMapOOBFill = TensorMapOOBFill.NONE

    def __post_init__(self) -> None:
        # Sequences become tuples of ints; a missing element_strides stays None.
        self.box_dim = _normalize_tensor_map_sequence("box_dim", self.box_dim)
        strides = self.element_strides
        if strides is not None:
            self.element_strides = _normalize_tensor_map_sequence("element_strides", strides)

        self.data_type = _normalize_tensor_map_data_type(self.data_type)

        # Enum-valued fields must already be members of their enum type.
        # Checked in declaration order so the first bad field is reported.
        enum_fields = (
            ("interleave", TensorMapInterleave),
            ("swizzle", TensorMapSwizzle),
            ("l2_promotion", TensorMapL2Promotion),
            ("oob_fill", TensorMapOOBFill),
        )
        for field_name, enum_type in enum_fields:
            checked = _require_tensor_map_enum(field_name, getattr(self, field_name), enum_type)
            setattr(self, field_name, checked)
def _coerce_tensor_map_descriptor_options(
    box_dim,
    options,
    *,
    element_strides,
    data_type,
    interleave,
    swizzle,
    l2_promotion,
    oob_fill,
):
    """Produce a :class:`TensorMapDescriptorOptions` from either calling style.

    Callers may pass a bundled ``options`` value *or* ``box_dim`` plus the
    individual keyword arguments, but not a mixture of both.

    Raises
    ------
    TypeError
        If ``options`` is combined with any non-default individual argument,
        or if neither ``options`` nor ``box_dim`` is supplied.
    """
    if options is None:
        # Individual-argument style: box_dim is the only mandatory piece.
        if box_dim is None:
            raise TypeError("box_dim is required unless options is provided")
        return TensorMapDescriptorOptions(
            box_dim=box_dim,
            element_strides=element_strides,
            data_type=data_type,
            interleave=interleave,
            swizzle=swizzle,
            l2_promotion=l2_promotion,
            oob_fill=oob_fill,
        )

    # Bundled style: every individual argument must still be at its default.
    has_individual_args = (
        box_dim is not None
        or element_strides is not None
        or data_type is not None
        or interleave != TensorMapInterleave.NONE
        or swizzle != TensorMapSwizzle.NONE
        or l2_promotion != TensorMapL2Promotion.NONE
        or oob_fill != TensorMapOOBFill.NONE
    )
    if has_individual_args:
        raise TypeError(
            "Specify either options or the individual tensor map arguments, not both")

    return check_or_create_options(
        TensorMapDescriptorOptions,
        options,
        "Tensor map descriptor options",
    )
# Lookup table: numpy dtype object -> TMA data type (plain int).
_NUMPY_DTYPE_TO_TMA = {
    numpy.dtype(scalar_type): tma_code
    for scalar_type, tma_code in (
        (numpy.uint8, _TMA_DT_UINT8),
        (numpy.uint16, _TMA_DT_UINT16),
        (numpy.uint32, _TMA_DT_UINT32),
        (numpy.int32, _TMA_DT_INT32),
        (numpy.uint64, _TMA_DT_UINT64),
        (numpy.int64, _TMA_DT_INT64),
        (numpy.float16, _TMA_DT_FLOAT16),
        (numpy.float32, _TMA_DT_FLOAT32),
        (numpy.float64, _TMA_DT_FLOAT64),
    )
}

# bfloat16 is only representable as a dtype when ml_dtypes was importable.
if ml_bfloat16 is not None:
    _NUMPY_DTYPE_TO_TMA[numpy.dtype(ml_bfloat16)] = _TMA_DT_BFLOAT16
# Lookup table: TMA data type (plain int) -> size of one element in bytes.
# Used to turn element strides into byte strides on the driver-API path.
_TMA_DATA_TYPE_SIZE = {
    _TMA_DT_UINT8: 1,
    _TMA_DT_UINT16: 2,
    _TMA_DT_UINT32: 4,
    _TMA_DT_INT32: 4,
    _TMA_DT_UINT64: 8,
    _TMA_DT_INT64: 8,
    _TMA_DT_FLOAT16: 2,
    _TMA_DT_FLOAT32: 4,
    _TMA_DT_FLOAT64: 8,
    _TMA_DT_BFLOAT16: 2,
    _TMA_DT_FLOAT32_FTZ: 4,
    _TMA_DT_TFLOAT32: 4,
    _TMA_DT_TFLOAT32_FTZ: 4,
}
def _resolve_data_type(view, data_type):
    """Determine the TMA data type for ``view``.

    An explicit ``data_type`` wins: a :class:`TensorMapDataType` member is
    converted straight to ``int``, anything else is normalised to a numpy
    dtype and looked up in ``_NUMPY_DTYPE_TO_TMA``.  With no explicit value
    the dtype is taken from ``view.dtype``.

    Returns
    -------
    int
        The TMA data type value.

    Raises
    ------
    ValueError
        If the dtype cannot be inferred (``view.dtype`` is ``None``) or has
        no entry in ``_NUMPY_DTYPE_TO_TMA``.
    """
    explicit = data_type is not None

    if explicit:
        if isinstance(data_type, TensorMapDataType):
            return int(data_type)
        dt = _normalize_tensor_map_data_type(data_type)
    else:
        dt = view.dtype
        if dt is None:
            raise ValueError(
                "Cannot infer TMA data type from the tensor; "
                "please specify data_type explicitly")

    tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt)
    if tma_dt is not None:
        return tma_dt

    message = (
        f"Unsupported dtype {dt} for TMA; "
        f"supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}.")
    if not explicit:
        # Only the inferred path hints at the explicit override.
        message += " You may also specify data_type explicitly."
    raise ValueError(message)
cdef inline bint _tma_dtype_to_dlpack(
    int tma_dt,
    uint8_t* out_code,
    uint8_t* out_bits,
    uint16_t* out_lanes,
) noexcept:
    # Translate a TMA data type into a DLPack (code, bits, lanes) triple.
    #
    # Returns True and fills the three output slots when ``tma_dt`` has a
    # DLPack equivalent.  Returns False and leaves the outputs untouched for
    # every other value (FLOAT32_FTZ, TFLOAT32 and TFLOAT32_FTZ are not
    # handled here).
    cdef uint8_t code
    cdef uint8_t bits

    if tma_dt == _TMA_DT_UINT8:
        code = <uint8_t>kDLUInt
        bits = <uint8_t>8
    elif tma_dt == _TMA_DT_UINT16:
        code = <uint8_t>kDLUInt
        bits = <uint8_t>16
    elif tma_dt == _TMA_DT_UINT32:
        code = <uint8_t>kDLUInt
        bits = <uint8_t>32
    elif tma_dt == _TMA_DT_UINT64:
        code = <uint8_t>kDLUInt
        bits = <uint8_t>64
    elif tma_dt == _TMA_DT_INT32:
        code = <uint8_t>kDLInt
        bits = <uint8_t>32
    elif tma_dt == _TMA_DT_INT64:
        code = <uint8_t>kDLInt
        bits = <uint8_t>64
    elif tma_dt == _TMA_DT_FLOAT16:
        code = <uint8_t>kDLFloat
        bits = <uint8_t>16
    elif tma_dt == _TMA_DT_FLOAT32:
        code = <uint8_t>kDLFloat
        bits = <uint8_t>32
    elif tma_dt == _TMA_DT_FLOAT64:
        code = <uint8_t>kDLFloat
        bits = <uint8_t>64
    elif tma_dt == _TMA_DT_BFLOAT16:
        code = <uint8_t>kDLBfloat
        bits = <uint8_t>16
    else:
        return False

    out_code[0] = code
    out_bits[0] = bits
    # Every supported type is a scalar, hence a single lane.
    out_lanes[0] = <uint16_t>1
    return True
cdef inline int _validate_tensor_map_view(view) except -1:
    # Check the two preconditions shared by every descriptor factory: the
    # view must be device-accessible and its base pointer 16-byte aligned.
    # Returns 0 on success; raises ValueError otherwise.
    if not view.is_device_accessible:
        raise ValueError("The tensor must be device-accessible")

    address = view.ptr
    if address % 16:
        raise ValueError(
            f"Global memory address must be 16-byte aligned, "
            f"got address 0x{address:x}")
    return 0
def _get_validated_view(tensor):
    """Return a validated :class:`StridedMemoryView` for ``tensor``.

    An existing ``StridedMemoryView`` is used as-is; any other object is
    wrapped via ``StridedMemoryView.from_any_interface``.  The view is then
    checked for device accessibility and 16-byte pointer alignment.
    """
    if not isinstance(tensor, StridedMemoryView):
        # stream_ptr=-1 skips stream synchronisation: building a descriptor
        # only reads tensor metadata and never touches the data itself.
        tensor = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1)
    _validate_tensor_map_view(tensor)
    return tensor
def _require_view_device(view, expected_device_id, operation):
    """Reject device-local tensors that live on a different CUDA device.

    DLPack reports host/managed CUDA memory as ``kDLCUDAHost`` /
    ``kDLCUDAManaged`` with ``device_id=0`` whatever the current device is,
    so the device-id comparison is applied to ``kDLCUDA`` tensors only.

    Raises
    ------
    ValueError
        If the view is a ``kDLCUDA`` tensor whose device id differs from
        ``expected_device_id``; ``operation`` names the caller in the message.
    """
    device_type, device_id = view.__dlpack_device__()
    if device_type != _kDLCUDA:
        return
    if device_id != expected_device_id:
        raise ValueError(
            f"{operation} expects tensor on device {expected_device_id}, got {device_id}")
cdef inline intptr_t _get_current_context_ptr() except? 0:
    # Return the calling thread's current CUDA context handle as an integer.
    # Raises RuntimeError when no context is current; driver errors are
    # surfaced through HANDLE_RETURN.
    cdef cydriver.CUcontext current
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetCurrent(&current))
    if current != NULL:
        return <intptr_t>current
    raise RuntimeError("TensorMapDescriptor requires an active CUDA context")
cdef inline int _get_current_device_id() except -1:
    # Return the device ordinal of the current CUDA context, as reported by
    # cuCtxGetDevice.  Driver errors are surfaced through HANDLE_RETURN.
    cdef cydriver.CUdevice current_device
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetDevice(&current_device))
    return <int>current_device
def _compute_byte_strides(shape, strides, elem_size):
    """Return per-dimension byte strides as a tuple, in row-major order.

    When ``strides`` (in elements) is given, each entry is scaled by
    ``elem_size``.  When it is ``None`` the layout is taken to be
    C-contiguous and the strides are derived from ``shape``.
    """
    if strides is not None:
        return tuple(s * elem_size for s in strides)

    # C-contiguous: walk from the innermost dimension outwards, where each
    # stride is the product of the element size and all inner extents.
    innermost_first = []
    running = elem_size
    for extent in reversed(shape):
        innermost_first.append(running)
        running *= extent
    return tuple(reversed(innermost_first))
def _validate_element_strides(element_strides, rank):
    """Return ``element_strides`` after a length check, or all-ones if absent.

    Raises
    ------
    ValueError
        If ``element_strides`` is given but does not have ``rank`` entries.
    """
    if element_strides is None:
        return (1,) * rank
    if len(element_strides) != rank:
        raise ValueError(
            f"element_strides must have {rank} elements, got {len(element_strides)}")
    return element_strides
463cdef class TensorMapDescriptor:
464 """Describes a TMA (Tensor Memory Accelerator) tensor map for Hopper+ GPUs.
466 A ``TensorMapDescriptor`` wraps the opaque 128-byte ``CUtensorMap`` struct
467 used by the hardware TMA unit for efficient bulk data movement between
468 global and shared memory.
470 Public tiled descriptors are created via
471 :meth:`cuda.core.StridedMemoryView.as_tensor_map`. Specialized
472 ``_from_*`` helpers remain private while this API surface settles, and
473 descriptors can be passed directly to :func:`~cuda.core.launch` as a
474 kernel argument.
475 """
477 def __init__(self):
478 raise RuntimeError( 1f
479 "TensorMapDescriptor cannot be instantiated directly. "
480 "Use StridedMemoryView.as_tensor_map() instead.")
482 cdef void* _get_data_ptr(self):
483 return <void*>&self._tensor_map
485 cdef int _check_context_compat(self) except -1:
486 cdef cydriver.CUcontext current_ctx
487 cdef cydriver.CUdevice current_dev
488 cdef int current_dev_id
489 if self._context == 0 and self._device_id < 0:
490 return 0
491 with nogil:
492 HANDLE_RETURN(cydriver.cuCtxGetCurrent(¤t_ctx))
493 if current_ctx == NULL:
494 raise RuntimeError("TensorMapDescriptor requires an active CUDA context")
495 if self._context != 0 and <intptr_t>current_ctx != self._context:
496 raise RuntimeError(
497 "TensorMapDescriptor was created in a different CUDA context")
498 with nogil:
499 HANDLE_RETURN(cydriver.cuCtxGetDevice(¤t_dev))
500 current_dev_id = <int>current_dev
501 if self._device_id >= 0 and current_dev_id != self._device_id:
502 raise RuntimeError(
503 f"TensorMapDescriptor belongs to device {self._device_id}, "
504 f"but current device is {current_dev_id}")
505 return 0
507 @property
508 def device(self) -> Device | None:
509 """Return the :obj:`~cuda.core.Device` associated with this descriptor."""
510 if self._device_id >= 0:
511 from cuda.core._device import Device
512 return Device(self._device_id)
513 return None
515 @classmethod
516 def _from_tiled(cls, view, box_dim=None, *,
517 options=None,
518 element_strides=None,
519 data_type=None,
520 interleave=TensorMapInterleave.NONE,
521 swizzle=TensorMapSwizzle.NONE,
522 l2_promotion=TensorMapL2Promotion.NONE,
523 oob_fill=TensorMapOOBFill.NONE):
524 """Create a tiled TMA descriptor from a validated view.
526 Parameters
527 ----------
528 view : StridedMemoryView
529 A device-accessible view with a 16-byte-aligned pointer.
530 box_dim : tuple of int, optional
531 The size of each tile dimension (in elements). Must have the
532 same rank as the tensor and each value must be in [1, 256].
533 Specified in the same (row-major) order as the tensor shape.
534 Required unless ``options`` is provided.
535 options : TensorMapDescriptorOptions or mapping, optional
536 Bundled tiled-descriptor options. When provided, do not also pass
537 ``box_dim`` or the individual option kwargs.
538 element_strides : tuple of int, optional
539 Per-dimension element traversal strides. Default is all 1s.
540 Specified in the same (row-major) order as the tensor shape.
541 data_type : dtype-like or TensorMapDataType, optional
542 Explicit dtype override. If ``None``, inferred from the tensor's
543 dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
544 accepted for compatibility.
545 interleave : TensorMapInterleave
546 Interleave layout. Default ``NONE``.
547 swizzle : TensorMapSwizzle
548 Swizzle mode. Default ``NONE``.
549 l2_promotion : TensorMapL2Promotion
550 L2 promotion mode. Default ``NONE``.
551 oob_fill : TensorMapOOBFill
552 Out-of-bounds fill mode. Default ``NONE``.
554 Returns
555 -------
556 TensorMapDescriptor
558 Raises
559 ------
560 ValueError
561 If the tensor rank is outside [1, 5], the pointer is not
562 16-byte aligned, or dimension/stride constraints are violated.
563 """
564 cdef TensorMapDescriptor desc = cls.__new__(cls)
566 opts = _coerce_tensor_map_descriptor_options(
567 box_dim,
568 options,
569 element_strides=element_strides,
570 data_type=data_type,
571 interleave=interleave,
572 swizzle=swizzle,
573 l2_promotion=l2_promotion,
574 oob_fill=oob_fill,
575 )
576 box_dim = opts.box_dim
577 element_strides = opts.element_strides
578 data_type = opts.data_type
579 interleave = opts.interleave
580 swizzle = opts.swizzle
581 l2_promotion = opts.l2_promotion
582 oob_fill = opts.oob_fill
584 _validate_tensor_map_view(view)
585 # Keep both the original tensor object and the validated view alive.
586 # For DLPack exporters, the view may hold the owning capsule whose
587 # deleter can free the backing allocation when released.
588 desc._source_ref = view.exporting_obj
589 desc._view_ref = view
590 desc._context = _get_current_context_ptr()
591 desc._device_id = _get_current_device_id()
592 _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_tiled")
594 tma_dt = _resolve_data_type(view, data_type)
595 cdef int c_data_type_int = tma_dt
596 cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int
598 cdef intptr_t global_address = view.ptr
599 shape = view.shape
601 cdef int rank = len(shape)
602 if rank < 1 or rank > 5:
603 raise ValueError(
604 f"Tensor rank must be between 1 and 5, got {rank}")
606 if len(box_dim) != rank:
607 raise ValueError(
608 f"box_dim must have {rank} elements (same as tensor rank), "
609 f"got {len(box_dim)}")
611 for i, bd in enumerate(box_dim):
612 if bd < 1 or bd > 256:
613 raise ValueError(
614 f"box_dim[{i}] must be in [1, 256], got {bd}")
616 cdef bint elem_strides_provided = element_strides is not None
617 element_strides = _validate_element_strides(element_strides, rank)
619 # Reuse CCCL/libcu++'s DLPack -> CUtensorMap conversion when possible.
620 # This avoids maintaining a second, independent validation/encoding implementation.
621 cdef uint8_t dl_code
622 cdef uint8_t dl_bits
623 cdef uint16_t dl_lanes
624 cdef int64_t c_shape[5]
625 cdef int64_t c_strides[5]
626 cdef int c_box_sizes[5]
627 cdef int c_elem_strides[5]
628 cdef const int64_t* c_strides_ptr
629 cdef const int* c_elem_strides_ptr
630 cdef char errbuf[512]
631 cdef int i_cccl
632 cdef int device_type
633 cdef int c_device_id
634 cdef int dl_device_type
635 cdef int dl_device_id
636 cdef int c_cccl_interleave_int
637 cdef int c_cccl_swizzle_int
638 cdef int c_cccl_l2_promotion_int
639 cdef int c_cccl_oob_fill_int
640 cdef int rc
641 if _tma_dtype_to_dlpack(tma_dt, &dl_code, &dl_bits, &dl_lanes):
642 c_strides_ptr = NULL
643 c_elem_strides_ptr = NULL
644 errbuf[0] = 0
646 for i_cccl in range(rank):
647 c_shape[i_cccl] = <int64_t>shape[i_cccl]
648 c_box_sizes[i_cccl] = <int>box_dim[i_cccl]
649 if elem_strides_provided:
650 c_elem_strides[i_cccl] = <int>element_strides[i_cccl]
652 if view.strides is not None:
653 for i_cccl in range(rank):
654 c_strides[i_cccl] = <int64_t>view.strides[i_cccl]
655 c_strides_ptr = &c_strides[0]
657 if elem_strides_provided:
658 c_elem_strides_ptr = &c_elem_strides[0]
660 dl_device_type, dl_device_id = view.__dlpack_device__()
661 device_type = dl_device_type
662 c_device_id = dl_device_id
663 c_cccl_interleave_int = int(interleave)
664 c_cccl_swizzle_int = int(swizzle)
665 c_cccl_l2_promotion_int = int(l2_promotion)
666 c_cccl_oob_fill_int = int(oob_fill)
668 with nogil:
669 rc = cuda_core_cccl_make_tma_descriptor_tiled(
670 <void*>&desc._tensor_map,
671 <void*>global_address,
672 device_type,
673 c_device_id,
674 rank,
675 &c_shape[0],
676 c_strides_ptr,
677 dl_code,
678 dl_bits,
679 dl_lanes,
680 &c_box_sizes[0],
681 c_elem_strides_ptr,
682 c_cccl_interleave_int,
683 c_cccl_swizzle_int,
684 c_cccl_l2_promotion_int,
685 c_cccl_oob_fill_int,
686 &errbuf[0],
687 <size_t>sizeof(errbuf),
688 )
690 if rc == 0:
691 desc._repr_info = {
692 "method": "tiled",
693 "rank": rank,
694 "data_type": TensorMapDataType(tma_dt),
695 "swizzle": swizzle,
696 }
697 return desc
699 msg = errbuf[:].split(b"\0", 1)[0].decode("utf-8", errors="replace")
700 # If CCCL isn't available at build time, fall back to the direct
701 # driver API path to preserve functionality on older toolchains.
702 if "not available at build time" not in msg:
703 raise ValueError(f"Failed to build TMA descriptor via CCCL: {msg}")
705 cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
706 byte_strides = _compute_byte_strides(shape, view.strides, elem_size)
708 # Reverse dimensions for column-major cuTensorMap convention
709 # Python/DLPack: row-major (dim 0 = outermost)
710 # cuTensorMap: column-major (dim 0 = innermost)
711 cdef uint64_t[5] c_global_dim
712 cdef uint64_t[4] c_global_strides # rank - 1 elements
713 cdef uint32_t[5] c_box_dim
714 cdef uint32_t[5] c_element_strides
715 cdef int i_c
717 for i_c in range(rank):
718 # Reverse: Python dim i -> cuTensorMap dim (rank - 1 - i)
719 c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
720 c_box_dim[i_c] = <uint32_t>box_dim[rank - 1 - i_c]
721 c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]
723 # globalStrides: rank-1 elements (byte strides for dims 1..N-1 in col-major order)
724 # The innermost stride (dim 0) is implicit = element size
725 for i_c in range(rank - 1):
726 c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]
728 cdef uint32_t c_rank = <uint32_t>rank
729 cdef int c_interleave_int = int(interleave)
730 cdef int c_swizzle_int = int(swizzle)
731 cdef int c_l2_promotion_int = int(l2_promotion)
732 cdef int c_oob_fill_int = int(oob_fill)
733 cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
734 cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
735 cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
736 cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int
738 with nogil:
739 HANDLE_RETURN(cydriver.cuTensorMapEncodeTiled(
740 &desc._tensor_map,
741 c_data_type,
742 c_rank,
743 <void*>global_address,
744 c_global_dim,
745 c_global_strides,
746 c_box_dim,
747 c_element_strides,
748 c_interleave,
749 c_swizzle,
750 c_l2_promotion,
751 c_oob_fill,
752 ))
754 desc._repr_info = {
755 "method": "tiled",
756 "rank": rank,
757 "data_type": TensorMapDataType(tma_dt),
758 "swizzle": swizzle,
759 }
761 return desc
763 @classmethod
764 def _from_im2col(cls, view, pixel_box_lower_corner, pixel_box_upper_corner,
765 channels_per_pixel, pixels_per_column, *,
766 element_strides=None,
767 data_type=None,
768 interleave=TensorMapInterleave.NONE,
769 swizzle=TensorMapSwizzle.NONE,
770 l2_promotion=TensorMapL2Promotion.NONE,
771 oob_fill=TensorMapOOBFill.NONE):
772 """Create an im2col TMA descriptor from a validated view.
774 Im2col layout is used for convolution-style data access patterns.
776 Parameters
777 ----------
778 view : StridedMemoryView
779 A device-accessible view with a 16-byte-aligned pointer.
780 pixel_box_lower_corner : tuple of int
781 Lower corner of the pixel bounding box for each spatial
782 dimension (rank - 2 elements). Specified in row-major order
783 matching the tensor's spatial dimensions.
784 pixel_box_upper_corner : tuple of int
785 Upper corner of the pixel bounding box for each spatial
786 dimension (rank - 2 elements). Specified in row-major order
787 matching the tensor's spatial dimensions.
788 channels_per_pixel : int
789 Number of channels per pixel.
790 pixels_per_column : int
791 Number of pixels per column.
792 element_strides : tuple of int, optional
793 Per-dimension element traversal strides. Default is all 1s.
794 data_type : dtype-like or TensorMapDataType, optional
795 Explicit dtype override. If ``None``, inferred from the tensor's
796 dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
797 accepted for compatibility.
798 interleave : TensorMapInterleave
799 Interleave layout. Default ``NONE``.
800 swizzle : TensorMapSwizzle
801 Swizzle mode. Default ``NONE``.
802 l2_promotion : TensorMapL2Promotion
803 L2 promotion mode. Default ``NONE``.
804 oob_fill : TensorMapOOBFill
805 Out-of-bounds fill mode. Default ``NONE``.
807 Returns
808 -------
809 TensorMapDescriptor
811 Raises
812 ------
813 ValueError
814 If the tensor rank is outside [3, 5], the pointer is not
815 16-byte aligned, or other constraints are violated.
816 """
817 cdef TensorMapDescriptor desc = cls.__new__(cls)
819 _validate_tensor_map_view(view)
820 desc._source_ref = view.exporting_obj
821 desc._view_ref = view
822 desc._context = _get_current_context_ptr()
823 desc._device_id = _get_current_device_id()
824 _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_im2col")
826 tma_dt = _resolve_data_type(view, data_type)
827 cdef int c_data_type_int = tma_dt
828 cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int
830 cdef intptr_t global_address = view.ptr
831 shape = view.shape
833 cdef int rank = len(shape)
834 if rank < 3 or rank > 5:
835 raise ValueError(
836 f"Im2col tensor rank must be between 3 and 5, got {rank}")
838 cdef int n_spatial = rank - 2
839 if len(pixel_box_lower_corner) != n_spatial:
840 raise ValueError(
841 f"pixel_box_lower_corner must have {n_spatial} elements "
842 f"(rank - 2), got {len(pixel_box_lower_corner)}")
843 if len(pixel_box_upper_corner) != n_spatial:
844 raise ValueError(
845 f"pixel_box_upper_corner must have {n_spatial} elements "
846 f"(rank - 2), got {len(pixel_box_upper_corner)}")
848 element_strides = _validate_element_strides(element_strides, rank)
850 cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
851 byte_strides = _compute_byte_strides(shape, view.strides, elem_size)
853 # Reverse all dimension arrays for column-major convention
854 cdef uint64_t[5] c_global_dim
855 cdef uint64_t[4] c_global_strides
856 cdef uint32_t[5] c_element_strides
857 cdef int[3] c_pixel_box_lower # max 3 spatial dims (rank 5 - 2)
858 cdef int[3] c_pixel_box_upper
859 cdef int i_c
861 for i_c in range(3):
862 c_pixel_box_lower[i_c] = 0
863 c_pixel_box_upper[i_c] = 0
865 for i_c in range(rank):
866 c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
867 c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]
869 for i_c in range(rank - 1):
870 c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]
872 # Reverse spatial dimensions for lower/upper corners
873 for i_c in range(n_spatial):
874 c_pixel_box_lower[i_c] = <int>pixel_box_lower_corner[n_spatial - 1 - i_c]
875 c_pixel_box_upper[i_c] = <int>pixel_box_upper_corner[n_spatial - 1 - i_c]
877 cdef uint32_t c_rank = <uint32_t>rank
878 cdef uint32_t c_channels = <uint32_t>channels_per_pixel
879 cdef uint32_t c_pixels = <uint32_t>pixels_per_column
880 cdef int c_interleave_int = int(interleave)
881 cdef int c_swizzle_int = int(swizzle)
882 cdef int c_l2_promotion_int = int(l2_promotion)
883 cdef int c_oob_fill_int = int(oob_fill)
884 cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
885 cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
886 cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
887 cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int
889 with nogil:
890 HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2col(
891 &desc._tensor_map,
892 c_data_type,
893 c_rank,
894 <void*>global_address,
895 c_global_dim,
896 c_global_strides,
897 c_pixel_box_lower,
898 c_pixel_box_upper,
899 c_channels,
900 c_pixels,
901 c_element_strides,
902 c_interleave,
903 c_swizzle,
904 c_l2_promotion,
905 c_oob_fill,
906 ))
908 desc._repr_info = {
909 "method": "im2col",
910 "rank": rank,
911 "data_type": TensorMapDataType(tma_dt),
912 "swizzle": swizzle,
913 }
915 return desc
@classmethod
def _from_im2col_wide(
    cls,
    view,
    pixel_box_lower_corner_width,
    pixel_box_upper_corner_width,
    channels_per_pixel,
    pixels_per_column,
    *,
    element_strides=None,
    data_type=None,
    interleave=TensorMapInterleave.NONE,
    mode=TensorMapIm2ColWideMode.W,
    swizzle=TensorMapSwizzle.SWIZZLE_128B,
    l2_promotion=TensorMapL2Promotion.NONE,
    oob_fill=TensorMapOOBFill.NONE
):
    """Build an im2col-wide TMA descriptor on top of a validated view.

    The im2col-wide layout loads elements exclusively along the W (width)
    dimension. It is supported on compute capability 10.0+ (Blackwell and
    later) and is only compiled in for CUDA 13+ builds.

    Parameters
    ----------
    view : StridedMemoryView
        A device-accessible view with a 16-byte-aligned pointer.
    pixel_box_lower_corner_width : int
        Lower corner of the pixel bounding box along the W dimension.
    pixel_box_upper_corner_width : int
        Upper corner of the pixel bounding box along the W dimension.
    channels_per_pixel : int
        Number of channels per pixel.
    pixels_per_column : int
        Number of pixels per column.
    element_strides : tuple of int, optional
        Per-dimension element traversal strides. Default is all 1s.
    data_type : dtype-like or TensorMapDataType, optional
        Explicit dtype override. If ``None``, inferred from the tensor's
        dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
        accepted for compatibility.
    interleave : TensorMapInterleave
        Interleave layout. Default ``NONE``.
    mode : TensorMapIm2ColWideMode
        Im2col wide mode. Default ``W``.
    swizzle : TensorMapSwizzle
        Swizzle mode. Default ``SWIZZLE_128B``.
    l2_promotion : TensorMapL2Promotion
        L2 promotion mode. Default ``NONE``.
    oob_fill : TensorMapOOBFill
        Out-of-bounds fill mode. Default ``NONE``.

    Returns
    -------
    TensorMapDescriptor

    Raises
    ------
    RuntimeError
        If this extension was built against a CUDA major version below 13.
    ValueError
        If the tensor rank is outside [3, 5], the pointer is not
        16-byte aligned, or other constraints are violated.
    """
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        # C-level locals are declared together; each one is assigned below
        # in the order the corresponding inputs are validated/converted.
        cdef TensorMapDescriptor desc
        cdef int dtype_code
        cdef cydriver.CUtensorMapDataType drv_dtype
        cdef intptr_t base_ptr
        cdef int rank
        cdef int elem_size
        cdef uint64_t[5] drv_dims
        cdef uint64_t[4] drv_strides
        cdef uint32_t[5] drv_elem_strides
        cdef int dst
        cdef int src
        cdef uint32_t drv_rank
        cdef int lower_w
        cdef int upper_w
        cdef uint32_t n_channels
        cdef uint32_t n_pixels
        cdef int interleave_code
        cdef int mode_code
        cdef int swizzle_code
        cdef int l2_promotion_code
        cdef int oob_fill_code
        cdef cydriver.CUtensorMapInterleave drv_interleave
        cdef cydriver.CUtensorMapIm2ColWideMode drv_mode
        cdef cydriver.CUtensorMapSwizzle drv_swizzle
        cdef cydriver.CUtensorMapL2promotion drv_l2_promotion
        cdef cydriver.CUtensorMapFloatOOBfill drv_oob_fill

        desc = cls.__new__(cls)

        # Validate the view, then pin it (and its exporter) on the descriptor
        # together with the current context/device.
        _validate_tensor_map_view(view)
        desc._source_ref = view.exporting_obj
        desc._view_ref = view
        desc._context = _get_current_context_ptr()
        desc._device_id = _get_current_device_id()
        _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_im2col_wide")

        # Resolve the element type and convert it to the driver enum.
        tma_dt = _resolve_data_type(view, data_type)
        dtype_code = tma_dt
        drv_dtype = <cydriver.CUtensorMapDataType>dtype_code

        base_ptr = view.ptr
        shape = view.shape

        rank = len(shape)
        if not (3 <= rank <= 5):
            raise ValueError(
                f"Im2col-wide tensor rank must be between 3 and 5, got {rank}")

        checked_elem_strides = _validate_element_strides(element_strides, rank)

        elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
        byte_strides = _compute_byte_strides(shape, view.strides, elem_size)

        # The driver arrays are filled in reversed dimension order
        # (column-major convention).
        for dst in range(rank):
            src = rank - 1 - dst
            drv_dims[dst] = <uint64_t>shape[src]
            drv_elem_strides[dst] = <uint32_t>checked_elem_strides[src]

        # Only rank - 1 byte strides are passed to the driver.
        for dst in range(rank - 1):
            drv_strides[dst] = <uint64_t>byte_strides[rank - 2 - dst]

        drv_rank = <uint32_t>rank
        lower_w = <int>pixel_box_lower_corner_width
        upper_w = <int>pixel_box_upper_corner_width
        n_channels = <uint32_t>channels_per_pixel
        n_pixels = <uint32_t>pixels_per_column

        # Enum-like arguments are first converted to plain ints and then
        # cast to the matching driver enum types.
        interleave_code = int(interleave)
        mode_code = int(mode)
        swizzle_code = int(swizzle)
        l2_promotion_code = int(l2_promotion)
        oob_fill_code = int(oob_fill)
        drv_interleave = <cydriver.CUtensorMapInterleave>interleave_code
        drv_mode = <cydriver.CUtensorMapIm2ColWideMode>mode_code
        drv_swizzle = <cydriver.CUtensorMapSwizzle>swizzle_code
        drv_l2_promotion = <cydriver.CUtensorMapL2promotion>l2_promotion_code
        drv_oob_fill = <cydriver.CUtensorMapFloatOOBfill>oob_fill_code

        with nogil:
            HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2colWide(
                &desc._tensor_map,
                drv_dtype,
                drv_rank,
                <void*>base_ptr,
                drv_dims,
                drv_strides,
                lower_w,
                upper_w,
                n_channels,
                n_pixels,
                drv_elem_strides,
                drv_interleave,
                drv_mode,
                drv_swizzle,
                drv_l2_promotion,
                drv_oob_fill,
            ))

        # Summary consumed by __repr__.
        desc._repr_info = {
            "method": "im2col_wide",
            "rank": rank,
            "data_type": TensorMapDataType(tma_dt),
            "swizzle": swizzle,
        }

        return desc
    ELSE:
        raise RuntimeError(
            "TensorMapDescriptor._from_im2col_wide requires a CUDA 13+ build")
def replace_address(self, tensor: object) -> None:
    """Point this tensor map descriptor at a different tensor's memory.

    Only the global memory address held by the descriptor is replaced.
    This is intended for the case where the tensor data has been
    reallocated while the shape, strides, and other parameters are
    unchanged.

    Parameters
    ----------
    tensor : object
        Any object supporting DLPack or ``__cuda_array_interface__``,
        or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
        device-accessible memory with a 16-byte-aligned pointer.
    """
    cdef intptr_t new_address

    self._check_context_compat()
    replacement = _get_validated_view(tensor)
    _require_view_device(replacement, self._device_id, "replace_address")
    new_address = replacement.ptr

    with nogil:
        HANDLE_RETURN(cydriver.cuTensorMapReplaceAddress(
            &self._tensor_map,
            <void*>new_address,
        ))

    # The references are swapped only once the driver call has succeeded.
    # Dropping the previous tensor earlier could leave the CUtensorMap
    # struct with a dangling pointer if the call failed.
    new_owner = replacement.exporting_obj
    self._source_ref = new_owner
    self._view_ref = replacement
def __repr__(self) -> str:
    """Return a short summary of how this descriptor was created.

    The summary is built from ``self._repr_info``. When that is ``None``
    the result is ``"TensorMapDescriptor()"``; otherwise whichever of the
    ``"method"``, ``"rank"``, ``"data_type"`` and ``"swizzle"`` entries are
    present are listed in that order.

    Returns
    -------
    str
    """
    info = self._repr_info
    if info is None:
        return "TensorMapDescriptor()"

    parts = []
    if "method" in info:
        parts.append(info["method"])
    if "rank" in info:
        parts.append(f"rank={info['rank']}")
    if "data_type" in info:
        # Enum members are shown by name. A value without ``.name`` (for
        # example a plain int) is shown as-is so repr() never raises.
        dtype_value = info["data_type"]
        parts.append(f"dtype={getattr(dtype_value, 'name', dtype_value)}")
    if "swizzle" in info:
        # The constructors store the caller's ``swizzle`` argument
        # unmodified, so it may be a raw int rather than an enum member.
        swizzle_value = info["swizzle"]
        parts.append(f"swizzle={getattr(swizzle_value, 'name', swizzle_value)}")
    return f"TensorMapDescriptor({', '.join(parts)})"