Coverage for cuda / core / _tensor_map.pyx: 29.39%
524 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from libc.stdint cimport intptr_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t
6from libc.stddef cimport size_t
7from cuda.bindings cimport cydriver
8from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
9from cuda.core._dlpack cimport kDLInt, kDLUInt, kDLFloat, kDLBfloat, _kDLCUDA
11import enum
12from dataclasses import dataclass
14import numpy
16from cuda.core._memoryview import StridedMemoryView
17from cuda.core._utils.cuda_utils import check_or_create_options
cdef extern from "_cpp/tensor_map_cccl.h":
    # C shim over CCCL's DLPack -> CUtensorMap tiled-descriptor builder.
    # Returns 0 on success; on failure it writes a message into ``err``
    # (at most ``err_cap`` bytes, NUL-terminated) and returns non-zero.
    int cuda_core_cccl_make_tma_descriptor_tiled(
        void* out_tensor_map,      # destination CUtensorMap storage
        void* data,                # global-memory base address of the tensor
        int device_type,           # DLPack device type of ``data``
        int device_id,             # DLPack device id of ``data``
        int ndim,                  # tensor rank
        const int64_t* shape,      # row-major shape, ``ndim`` entries
        const int64_t* strides,    # element strides, or NULL for C-contiguous
        uint8_t dtype_code,        # DLPack dtype code
        uint8_t dtype_bits,        # DLPack dtype bits
        uint16_t dtype_lanes,      # DLPack dtype lanes
        const int* box_sizes,      # tile size per dimension, ``ndim`` entries
        const int* elem_strides,   # traversal strides, or NULL for all ones
        int interleave_layout,     # CUtensorMapInterleave value
        int swizzle,               # CUtensorMapSwizzle value
        int l2_fetch_size,         # CUtensorMapL2promotion value
        int oob_fill,              # CUtensorMapFloatOOBfill value
        char* err,                 # error-message buffer
        size_t err_cap) nogil      # capacity of ``err`` in bytes
# ``ml_dtypes`` is an optional dependency that supplies a numpy-compatible
# bfloat16 dtype; when absent, bfloat16 is simply not registered in the
# numpy-dtype -> TMA mapping below.
try:
    from ml_dtypes import bfloat16 as ml_bfloat16
except ImportError:
    ml_bfloat16 = None
class TensorMapDataType(enum.IntEnum):
    """Data types for tensor map descriptors.

    These correspond to the ``CUtensorMapDataType`` driver enum values.
    """

    UINT8 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8
    UINT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16
    UINT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32
    INT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32
    UINT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64
    INT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64
    FLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16
    FLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32
    FLOAT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64
    BFLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16
    # The FTZ/TF32 variants have no numpy dtype equivalent; they can only be
    # requested explicitly via this enum.
    FLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ
    TFLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32
    TFLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
class TensorMapInterleave(enum.IntEnum):
    """Interleave layout for tensor map descriptors.

    These correspond to the ``CUtensorMapInterleave`` driver enum values.
    """

    NONE = cydriver.CU_TENSOR_MAP_INTERLEAVE_NONE
    INTERLEAVE_16B = cydriver.CU_TENSOR_MAP_INTERLEAVE_16B
    INTERLEAVE_32B = cydriver.CU_TENSOR_MAP_INTERLEAVE_32B
class TensorMapSwizzle(enum.IntEnum):
    """Swizzle mode for tensor map descriptors.

    These correspond to the ``CUtensorMapSwizzle`` driver enum values.
    """

    NONE = cydriver.CU_TENSOR_MAP_SWIZZLE_NONE
    SWIZZLE_32B = cydriver.CU_TENSOR_MAP_SWIZZLE_32B
    SWIZZLE_64B = cydriver.CU_TENSOR_MAP_SWIZZLE_64B
    SWIZZLE_128B = cydriver.CU_TENSOR_MAP_SWIZZLE_128B
class TensorMapL2Promotion(enum.IntEnum):
    """L2 promotion mode for tensor map descriptors.

    These correspond to the ``CUtensorMapL2promotion`` driver enum values.
    """

    NONE = cydriver.CU_TENSOR_MAP_L2_PROMOTION_NONE
    L2_64B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_64B
    L2_128B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_128B
    L2_256B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_256B
class TensorMapOOBFill(enum.IntEnum):
    """Out-of-bounds fill mode for tensor map descriptors.

    These correspond to the ``CUtensorMapFloatOOBfill`` driver enum values.
    """

    NONE = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    NAN_REQUEST_ZERO_FMA = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
# Cython compile-time switch on the CUDA major version used at build time:
# only CUDA 13+ builds bind the real driver enum values.
IF CUDA_CORE_BUILD_MAJOR >= 13:
    class TensorMapIm2ColWideMode(enum.IntEnum):
        """Im2col wide mode for tensor map descriptors.

        These correspond to the ``CUtensorMapIm2ColWideMode`` driver enum values.
        Supported on compute capability 10.0+.
        """

        W = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W
        W128 = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W128
ELSE:
    class TensorMapIm2ColWideMode(enum.IntEnum):
        """Im2col wide mode for tensor map descriptors.

        This enum is always defined for API stability, but the
        :meth:`TensorMapDescriptor._from_im2col_wide` factory requires a CUDA 13+
        build and will raise otherwise.
        """

        W = 0
        W128 = 1
# Plain-``int`` copies of the driver data-type enum values. These serve as
# dict keys and in integer comparisons (e.g. in ``_tma_dtype_to_dlpack``)
# without going through the Python enum.
_TMA_DT_UINT8 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8)
_TMA_DT_UINT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16)
_TMA_DT_UINT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32)
_TMA_DT_INT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32)
_TMA_DT_UINT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64)
_TMA_DT_INT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64)
_TMA_DT_FLOAT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16)
_TMA_DT_FLOAT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32)
_TMA_DT_FLOAT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64)
_TMA_DT_BFLOAT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16)
_TMA_DT_FLOAT32_FTZ = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ)
_TMA_DT_TFLOAT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32)
_TMA_DT_TFLOAT32_FTZ = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ)
144def _normalize_tensor_map_data_type(data_type):
145 if data_type is None or isinstance(data_type, TensorMapDataType):
146 return data_type
147 try:
148 return numpy.dtype(data_type)
149 except TypeError as e:
150 raise TypeError(
151 "data_type must be a TensorMapDataType or a numpy/ml_dtypes dtype, "
152 f"got {type(data_type)}") from e
155def _normalize_tensor_map_sequence(name, values):
156 try:
157 values = tuple(values)
158 except TypeError as e:
159 raise TypeError(f"{name} must be a tuple of ints, got {type(values)}") from e
160 for i, value in enumerate(values):
161 if not isinstance(value, int):
162 raise TypeError(f"{name}[{i}] must be an int, got {type(value)}")
163 return values
166def _require_tensor_map_enum(name, value, enum_type):
167 if not isinstance(value, enum_type):
168 raise TypeError(f"{name} must be a {enum_type.__name__}, got {type(value)}")
169 return value
@dataclass
class TensorMapDescriptorOptions:
    """Options for :meth:`cuda.core.StridedMemoryView.as_tensor_map`.

    Attributes
    ----------
    box_dim : tuple[int, ...]
        Tile size for each tensor dimension, expressed in elements.
    element_strides : tuple[int, ...], optional
        Per-dimension element traversal strides.
    data_type : object, optional
        Explicit dtype override. Prefer NumPy or ``ml_dtypes`` dtype objects;
        :class:`TensorMapDataType` remains accepted for compatibility.
    interleave : TensorMapInterleave, optional
        Interleave layout. Default ``NONE``.
    swizzle : TensorMapSwizzle, optional
        Swizzle mode. Default ``NONE``.
    l2_promotion : TensorMapL2Promotion, optional
        L2 promotion mode. Default ``NONE``.
    oob_fill : TensorMapOOBFill, optional
        Out-of-bounds fill mode. Default ``NONE``.
    """

    box_dim: tuple[int, ...]
    element_strides: tuple[int, ...] | None = None
    data_type: object = None
    interleave: TensorMapInterleave = TensorMapInterleave.NONE
    swizzle: TensorMapSwizzle = TensorMapSwizzle.NONE
    l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE
    oob_fill: TensorMapOOBFill = TensorMapOOBFill.NONE

    def __post_init__(self):
        # Normalize the sequence/dtype fields in place.
        self.box_dim = _normalize_tensor_map_sequence("box_dim", self.box_dim)
        if self.element_strides is not None:
            self.element_strides = _normalize_tensor_map_sequence(
                "element_strides", self.element_strides)
        self.data_type = _normalize_tensor_map_data_type(self.data_type)
        # Every mode field must already be the matching enum type.
        for attr, enum_type in (
            ("interleave", TensorMapInterleave),
            ("swizzle", TensorMapSwizzle),
            ("l2_promotion", TensorMapL2Promotion),
            ("oob_fill", TensorMapOOBFill),
        ):
            setattr(self, attr, _require_tensor_map_enum(attr, getattr(self, attr), enum_type))
def _coerce_tensor_map_descriptor_options(
    box_dim,
    options,
    *,
    element_strides,
    data_type,
    interleave,
    swizzle,
    l2_promotion,
    oob_fill,
):
    """Merge the options-object and individual-kwarg call styles into one bundle.

    Exactly one of the two styles may be used; mixing them raises TypeError.
    """
    if options is None:
        # Individual-kwarg style: box_dim is the one mandatory piece.
        if box_dim is None:
            raise TypeError("box_dim is required unless options is provided")
        return TensorMapDescriptorOptions(
            box_dim=box_dim,
            element_strides=element_strides,
            data_type=data_type,
            interleave=interleave,
            swizzle=swizzle,
            l2_promotion=l2_promotion,
            oob_fill=oob_fill,
        )

    # Options-object style: reject any individually specified argument.
    individually_specified = (
        box_dim is not None
        or element_strides is not None
        or data_type is not None
        or interleave != TensorMapInterleave.NONE
        or swizzle != TensorMapSwizzle.NONE
        or l2_promotion != TensorMapL2Promotion.NONE
        or oob_fill != TensorMapOOBFill.NONE
    )
    if individually_specified:
        raise TypeError(
            "Specify either options or the individual tensor map arguments, not both")
    return check_or_create_options(
        TensorMapDescriptorOptions,
        options,
        "Tensor map descriptor options",
    )
# Mapping from numpy dtype to TMA data type.
# The FTZ/TF32 variants are deliberately absent: they have no numpy dtype and
# must be requested explicitly via TensorMapDataType.
_NUMPY_DTYPE_TO_TMA = {
    numpy.dtype(numpy.uint8): _TMA_DT_UINT8,
    numpy.dtype(numpy.uint16): _TMA_DT_UINT16,
    numpy.dtype(numpy.uint32): _TMA_DT_UINT32,
    numpy.dtype(numpy.int32): _TMA_DT_INT32,
    numpy.dtype(numpy.uint64): _TMA_DT_UINT64,
    numpy.dtype(numpy.int64): _TMA_DT_INT64,
    numpy.dtype(numpy.float16): _TMA_DT_FLOAT16,
    numpy.dtype(numpy.float32): _TMA_DT_FLOAT32,
    numpy.dtype(numpy.float64): _TMA_DT_FLOAT64,
}

# bfloat16 is only expressible as a numpy dtype when ml_dtypes is installed.
if ml_bfloat16 is not None:
    _NUMPY_DTYPE_TO_TMA[numpy.dtype(ml_bfloat16)] = _TMA_DT_BFLOAT16
# Mapping from TMA data type to element size in bytes.
# Used by the driver-API fallback path to convert element strides into the
# byte strides that cuTensorMapEncode* expects.
_TMA_DATA_TYPE_SIZE = {
    _TMA_DT_UINT8: 1,
    _TMA_DT_UINT16: 2,
    _TMA_DT_UINT32: 4,
    _TMA_DT_INT32: 4,
    _TMA_DT_UINT64: 8,
    _TMA_DT_INT64: 8,
    _TMA_DT_FLOAT16: 2,
    _TMA_DT_FLOAT32: 4,
    _TMA_DT_FLOAT64: 8,
    _TMA_DT_BFLOAT16: 2,
    _TMA_DT_FLOAT32_FTZ: 4,
    _TMA_DT_TFLOAT32: 4,
    _TMA_DT_TFLOAT32_FTZ: 4,
}
def _resolve_data_type(view, data_type):
    """Resolve the TMA data type from an explicit value or the view's dtype."""
    if data_type is None:
        # Infer from the view; fail loudly when the view carries no dtype.
        dt = view.dtype
        if dt is None:
            raise ValueError(
                "Cannot infer TMA data type from the tensor; "
                "please specify data_type explicitly")
        tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt)
        if tma_dt is None:
            raise ValueError(
                f"Unsupported dtype {dt} for TMA; "
                f"supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}. "
                "You may also specify data_type explicitly.")
        return tma_dt

    # Explicit override: enum values pass straight through as plain ints.
    if isinstance(data_type, TensorMapDataType):
        return int(data_type)
    dt = _normalize_tensor_map_data_type(data_type)
    tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt)
    if tma_dt is None:
        raise ValueError(
            f"Unsupported dtype {dt} for TMA; "
            f"supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}.")
    return tma_dt
cdef inline bint _tma_dtype_to_dlpack(
    int tma_dt,
    uint8_t* out_code,
    uint8_t* out_bits,
    uint16_t* out_lanes,
) noexcept:
    # Translate a TMA data-type code into its DLPack (code, bits, lanes)
    # triple. Returns True on success; returns False for types without a
    # DLPack equivalent (the FTZ/TF32 variants), leaving the out parameters
    # unwritten.
    if tma_dt == _TMA_DT_UINT8:
        out_code[0] = <uint8_t>kDLUInt
        out_bits[0] = <uint8_t>8
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_UINT16:
        out_code[0] = <uint8_t>kDLUInt
        out_bits[0] = <uint8_t>16
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_UINT32:
        out_code[0] = <uint8_t>kDLUInt
        out_bits[0] = <uint8_t>32
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_UINT64:
        out_code[0] = <uint8_t>kDLUInt
        out_bits[0] = <uint8_t>64
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_INT32:
        out_code[0] = <uint8_t>kDLInt
        out_bits[0] = <uint8_t>32
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_INT64:
        out_code[0] = <uint8_t>kDLInt
        out_bits[0] = <uint8_t>64
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_FLOAT16:
        out_code[0] = <uint8_t>kDLFloat
        out_bits[0] = <uint8_t>16
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_FLOAT32:
        out_code[0] = <uint8_t>kDLFloat
        out_bits[0] = <uint8_t>32
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_FLOAT64:
        out_code[0] = <uint8_t>kDLFloat
        out_bits[0] = <uint8_t>64
        out_lanes[0] = <uint16_t>1
        return True
    if tma_dt == _TMA_DT_BFLOAT16:
        out_code[0] = <uint8_t>kDLBfloat
        out_bits[0] = <uint8_t>16
        out_lanes[0] = <uint16_t>1
        return True
    return False
cdef inline int _validate_tensor_map_view(view) except -1:
    """Raise ValueError unless ``view`` is device-accessible and 16-byte aligned."""
    if not view.is_device_accessible:
        raise ValueError("The tensor must be device-accessible")
    # TMA requires the global-memory base address to be 16-byte aligned.
    if view.ptr % 16 != 0:
        raise ValueError(
            f"Global memory address must be 16-byte aligned, "
            f"got address 0x{view.ptr:x}")
    return 0
def _get_validated_view(tensor):
    """Obtain a device-accessible StridedMemoryView with a 16-byte-aligned pointer."""
    if not isinstance(tensor, StridedMemoryView):
        # stream_ptr=-1: no stream synchronization needed because descriptor
        # creation only reads tensor metadata, it does not move data.
        tensor = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1)
    _validate_tensor_map_view(tensor)
    return tensor
def _require_view_device(view, expected_device_id, operation):
    """Ensure device-local tensors match the current CUDA device.

    DLPack reports host/managed CUDA memory as ``kDLCUDAHost`` /
    ``kDLCUDAManaged`` with ``device_id=0`` regardless of the current device,
    so only true ``kDLCUDA`` tensors are rejected by device-id mismatch.
    """
    device_type, device_id = view.__dlpack_device__()
    if device_type != _kDLCUDA:
        return
    if device_id != expected_device_id:
        raise ValueError(
            f"{operation} expects tensor on device {expected_device_id}, got {device_id}")
cdef inline intptr_t _get_current_context_ptr() except? 0:
    """Return the current CUDA context as an integer, raising if none is active."""
    cdef cydriver.CUcontext ctx
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
    # cuCtxGetCurrent can succeed while reporting a NULL context, so the
    # "no active context" case needs an explicit check.
    if ctx == NULL:
        raise RuntimeError("TensorMapDescriptor requires an active CUDA context")
    return <intptr_t>ctx
cdef inline int _get_current_device_id() except -1:
    """Return the device ordinal of the current CUDA context."""
    cdef cydriver.CUdevice dev
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev))
    return <int>dev
430def _compute_byte_strides(shape, strides, elem_size):
431 """Compute byte strides from element strides or C-contiguous fallback.
433 Returns a tuple of byte strides in row-major order.
434 """
435 if strides is not None:
436 return tuple(s * elem_size for s in strides)
438 # C-contiguous: compute byte strides from shape, innermost first
439 rank = len(shape)
440 byte_strides = []
441 stride = elem_size
442 for i in range(rank - 1, -1, -1):
443 byte_strides.append(stride)
444 stride *= shape[i]
445 byte_strides.reverse()
446 return tuple(byte_strides)
449def _validate_element_strides(element_strides, rank):
450 """Validate or default element_strides to all-ones."""
451 if element_strides is not None:
452 if len(element_strides) != rank:
453 raise ValueError(
454 f"element_strides must have {rank} elements, got {len(element_strides)}")
455 return element_strides
456 return (1,) * rank
459cdef class TensorMapDescriptor:
460 """Describes a TMA (Tensor Memory Accelerator) tensor map for Hopper+ GPUs.
462 A ``TensorMapDescriptor`` wraps the opaque 128-byte ``CUtensorMap`` struct
463 used by the hardware TMA unit for efficient bulk data movement between
464 global and shared memory.
466 Public tiled descriptors are created via
467 :meth:`cuda.core.StridedMemoryView.as_tensor_map`. Specialized
468 ``_from_*`` helpers remain private while this API surface settles, and
469 descriptors can be passed directly to :func:`~cuda.core.launch` as a
470 kernel argument.
471 """
    def __init__(self):
        # Instances are only created internally via ``cls.__new__`` in the
        # ``_from_*`` factory classmethods; direct construction is an error.
        raise RuntimeError(
            "TensorMapDescriptor cannot be instantiated directly. "
            "Use StridedMemoryView.as_tensor_map() instead.")
    cdef void* _get_data_ptr(self):
        # Raw pointer to the embedded CUtensorMap struct, e.g. for passing the
        # descriptor as a kernel argument.
        return <void*>&self._tensor_map
481 cdef int _check_context_compat(self) except -1:
482 cdef cydriver.CUcontext current_ctx
483 cdef cydriver.CUdevice current_dev
484 cdef int current_dev_id
485 if self._context == 0 and self._device_id < 0:
486 return 0
487 with nogil:
488 HANDLE_RETURN(cydriver.cuCtxGetCurrent(¤t_ctx))
489 if current_ctx == NULL:
490 raise RuntimeError("TensorMapDescriptor requires an active CUDA context")
491 if self._context != 0 and <intptr_t>current_ctx != self._context:
492 raise RuntimeError(
493 "TensorMapDescriptor was created in a different CUDA context")
494 with nogil:
495 HANDLE_RETURN(cydriver.cuCtxGetDevice(¤t_dev))
496 current_dev_id = <int>current_dev
497 if self._device_id >= 0 and current_dev_id != self._device_id:
498 raise RuntimeError(
499 f"TensorMapDescriptor belongs to device {self._device_id}, "
500 f"but current device is {current_dev_id}")
501 return 0
503 @property
504 def device(self):
505 """Return the :obj:`~cuda.core.Device` associated with this descriptor."""
506 if self._device_id >= 0:
507 from cuda.core._device import Device
508 return Device(self._device_id)
    @classmethod
    def _from_tiled(cls, view, box_dim=None, *,
                    options=None,
                    element_strides=None,
                    data_type=None,
                    interleave=TensorMapInterleave.NONE,
                    swizzle=TensorMapSwizzle.NONE,
                    l2_promotion=TensorMapL2Promotion.NONE,
                    oob_fill=TensorMapOOBFill.NONE):
        """Create a tiled TMA descriptor from a validated view.

        Parameters
        ----------
        view : StridedMemoryView
            A device-accessible view with a 16-byte-aligned pointer.
        box_dim : tuple of int, optional
            The size of each tile dimension (in elements). Must have the
            same rank as the tensor and each value must be in [1, 256].
            Specified in the same (row-major) order as the tensor shape.
            Required unless ``options`` is provided.
        options : TensorMapDescriptorOptions or mapping, optional
            Bundled tiled-descriptor options. When provided, do not also pass
            ``box_dim`` or the individual option kwargs.
        element_strides : tuple of int, optional
            Per-dimension element traversal strides. Default is all 1s.
            Specified in the same (row-major) order as the tensor shape.
        data_type : dtype-like or TensorMapDataType, optional
            Explicit dtype override. If ``None``, inferred from the tensor's
            dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
            accepted for compatibility.
        interleave : TensorMapInterleave
            Interleave layout. Default ``NONE``.
        swizzle : TensorMapSwizzle
            Swizzle mode. Default ``NONE``.
        l2_promotion : TensorMapL2Promotion
            L2 promotion mode. Default ``NONE``.
        oob_fill : TensorMapOOBFill
            Out-of-bounds fill mode. Default ``NONE``.

        Returns
        -------
        TensorMapDescriptor

        Raises
        ------
        ValueError
            If the tensor rank is outside [1, 5], the pointer is not
            16-byte aligned, or dimension/stride constraints are violated.
        """
        cdef TensorMapDescriptor desc = cls.__new__(cls)

        # Collapse the two call styles (options object vs. individual kwargs)
        # into one validated options bundle.
        opts = _coerce_tensor_map_descriptor_options(
            box_dim,
            options,
            element_strides=element_strides,
            data_type=data_type,
            interleave=interleave,
            swizzle=swizzle,
            l2_promotion=l2_promotion,
            oob_fill=oob_fill,
        )
        box_dim = opts.box_dim
        element_strides = opts.element_strides
        data_type = opts.data_type
        interleave = opts.interleave
        swizzle = opts.swizzle
        l2_promotion = opts.l2_promotion
        oob_fill = opts.oob_fill

        _validate_tensor_map_view(view)
        # Keep both the original tensor object and the validated view alive.
        # For DLPack exporters, the view may hold the owning capsule whose
        # deleter can free the backing allocation when released.
        desc._source_ref = view.exporting_obj
        desc._view_ref = view
        desc._context = _get_current_context_ptr()
        desc._device_id = _get_current_device_id()
        _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_tiled")

        tma_dt = _resolve_data_type(view, data_type)
        cdef int c_data_type_int = tma_dt
        cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int

        cdef intptr_t global_address = view.ptr
        shape = view.shape

        cdef int rank = len(shape)
        if rank < 1 or rank > 5:
            raise ValueError(
                f"Tensor rank must be between 1 and 5, got {rank}")

        if len(box_dim) != rank:
            raise ValueError(
                f"box_dim must have {rank} elements (same as tensor rank), "
                f"got {len(box_dim)}")

        for i, bd in enumerate(box_dim):
            if bd < 1 or bd > 256:
                raise ValueError(
                    f"box_dim[{i}] must be in [1, 256], got {bd}")

        cdef bint elem_strides_provided = element_strides is not None
        element_strides = _validate_element_strides(element_strides, rank)

        # Reuse CCCL/libcu++'s DLPack -> CUtensorMap conversion when possible.
        # This avoids maintaining a second, independent validation/encoding implementation.
        cdef uint8_t dl_code
        cdef uint8_t dl_bits
        cdef uint16_t dl_lanes
        cdef int64_t c_shape[5]
        cdef int64_t c_strides[5]
        cdef int c_box_sizes[5]
        cdef int c_elem_strides[5]
        cdef const int64_t* c_strides_ptr
        cdef const int* c_elem_strides_ptr
        cdef char errbuf[512]
        cdef int i_cccl
        cdef int device_type
        cdef int c_device_id
        cdef int dl_device_type
        cdef int dl_device_id
        cdef int c_cccl_interleave_int
        cdef int c_cccl_swizzle_int
        cdef int c_cccl_l2_promotion_int
        cdef int c_cccl_oob_fill_int
        cdef int rc
        if _tma_dtype_to_dlpack(tma_dt, &dl_code, &dl_bits, &dl_lanes):
            # NULL stride / element-stride pointers tell the shim to assume
            # C-contiguous layout and unit traversal strides respectively.
            c_strides_ptr = NULL
            c_elem_strides_ptr = NULL
            errbuf[0] = 0

            for i_cccl in range(rank):
                c_shape[i_cccl] = <int64_t>shape[i_cccl]
                c_box_sizes[i_cccl] = <int>box_dim[i_cccl]
                if elem_strides_provided:
                    c_elem_strides[i_cccl] = <int>element_strides[i_cccl]

            if view.strides is not None:
                for i_cccl in range(rank):
                    c_strides[i_cccl] = <int64_t>view.strides[i_cccl]
                c_strides_ptr = &c_strides[0]

            if elem_strides_provided:
                c_elem_strides_ptr = &c_elem_strides[0]

            dl_device_type, dl_device_id = view.__dlpack_device__()
            device_type = dl_device_type
            c_device_id = dl_device_id
            c_cccl_interleave_int = int(interleave)
            c_cccl_swizzle_int = int(swizzle)
            c_cccl_l2_promotion_int = int(l2_promotion)
            c_cccl_oob_fill_int = int(oob_fill)

            with nogil:
                rc = cuda_core_cccl_make_tma_descriptor_tiled(
                    <void*>&desc._tensor_map,
                    <void*>global_address,
                    device_type,
                    c_device_id,
                    rank,
                    &c_shape[0],
                    c_strides_ptr,
                    dl_code,
                    dl_bits,
                    dl_lanes,
                    &c_box_sizes[0],
                    c_elem_strides_ptr,
                    c_cccl_interleave_int,
                    c_cccl_swizzle_int,
                    c_cccl_l2_promotion_int,
                    c_cccl_oob_fill_int,
                    &errbuf[0],
                    <size_t>sizeof(errbuf),
                )

            if rc == 0:
                desc._repr_info = {
                    "method": "tiled",
                    "rank": rank,
                    "data_type": TensorMapDataType(tma_dt),
                    "swizzle": swizzle,
                }
                return desc

            # The shim writes a NUL-terminated message into the fixed buffer.
            msg = errbuf[:].split(b"\0", 1)[0].decode("utf-8", errors="replace")
            # If CCCL isn't available at build time, fall back to the direct
            # driver API path to preserve functionality on older toolchains.
            if "not available at build time" not in msg:
                raise ValueError(f"Failed to build TMA descriptor via CCCL: {msg}")

        # --- Fallback: encode directly through the driver API. ---
        cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
        byte_strides = _compute_byte_strides(shape, view.strides, elem_size)

        # Reverse dimensions for column-major cuTensorMap convention
        # Python/DLPack: row-major (dim 0 = outermost)
        # cuTensorMap: column-major (dim 0 = innermost)
        cdef uint64_t[5] c_global_dim
        cdef uint64_t[4] c_global_strides  # rank - 1 elements
        cdef uint32_t[5] c_box_dim
        cdef uint32_t[5] c_element_strides
        cdef int i_c

        for i_c in range(rank):
            # Reverse: Python dim i -> cuTensorMap dim (rank - 1 - i)
            c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
            c_box_dim[i_c] = <uint32_t>box_dim[rank - 1 - i_c]
            c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]

        # globalStrides: rank-1 elements (byte strides for dims 1..N-1 in col-major order)
        # The innermost stride (dim 0) is implicit = element size
        for i_c in range(rank - 1):
            c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]

        cdef uint32_t c_rank = <uint32_t>rank
        cdef int c_interleave_int = int(interleave)
        cdef int c_swizzle_int = int(swizzle)
        cdef int c_l2_promotion_int = int(l2_promotion)
        cdef int c_oob_fill_int = int(oob_fill)
        cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
        cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
        cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
        cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int

        with nogil:
            HANDLE_RETURN(cydriver.cuTensorMapEncodeTiled(
                &desc._tensor_map,
                c_data_type,
                c_rank,
                <void*>global_address,
                c_global_dim,
                c_global_strides,
                c_box_dim,
                c_element_strides,
                c_interleave,
                c_swizzle,
                c_l2_promotion,
                c_oob_fill,
            ))

        desc._repr_info = {
            "method": "tiled",
            "rank": rank,
            "data_type": TensorMapDataType(tma_dt),
            "swizzle": swizzle,
        }

        return desc
    @classmethod
    def _from_im2col(cls, view, pixel_box_lower_corner, pixel_box_upper_corner,
                     channels_per_pixel, pixels_per_column, *,
                     element_strides=None,
                     data_type=None,
                     interleave=TensorMapInterleave.NONE,
                     swizzle=TensorMapSwizzle.NONE,
                     l2_promotion=TensorMapL2Promotion.NONE,
                     oob_fill=TensorMapOOBFill.NONE):
        """Create an im2col TMA descriptor from a validated view.

        Im2col layout is used for convolution-style data access patterns.

        Parameters
        ----------
        view : StridedMemoryView
            A device-accessible view with a 16-byte-aligned pointer.
        pixel_box_lower_corner : tuple of int
            Lower corner of the pixel bounding box for each spatial
            dimension (rank - 2 elements). Specified in row-major order
            matching the tensor's spatial dimensions.
        pixel_box_upper_corner : tuple of int
            Upper corner of the pixel bounding box for each spatial
            dimension (rank - 2 elements). Specified in row-major order
            matching the tensor's spatial dimensions.
        channels_per_pixel : int
            Number of channels per pixel.
        pixels_per_column : int
            Number of pixels per column.
        element_strides : tuple of int, optional
            Per-dimension element traversal strides. Default is all 1s.
        data_type : dtype-like or TensorMapDataType, optional
            Explicit dtype override. If ``None``, inferred from the tensor's
            dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
            accepted for compatibility.
        interleave : TensorMapInterleave
            Interleave layout. Default ``NONE``.
        swizzle : TensorMapSwizzle
            Swizzle mode. Default ``NONE``.
        l2_promotion : TensorMapL2Promotion
            L2 promotion mode. Default ``NONE``.
        oob_fill : TensorMapOOBFill
            Out-of-bounds fill mode. Default ``NONE``.

        Returns
        -------
        TensorMapDescriptor

        Raises
        ------
        ValueError
            If the tensor rank is outside [3, 5], the pointer is not
            16-byte aligned, or other constraints are violated.
        """
        cdef TensorMapDescriptor desc = cls.__new__(cls)

        _validate_tensor_map_view(view)
        # Keep the exporting object and the view alive for the descriptor's
        # lifetime (the view may own the backing allocation via its capsule).
        desc._source_ref = view.exporting_obj
        desc._view_ref = view
        desc._context = _get_current_context_ptr()
        desc._device_id = _get_current_device_id()
        _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_im2col")

        tma_dt = _resolve_data_type(view, data_type)
        cdef int c_data_type_int = tma_dt
        cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int

        cdef intptr_t global_address = view.ptr
        shape = view.shape

        cdef int rank = len(shape)
        if rank < 3 or rank > 5:
            raise ValueError(
                f"Im2col tensor rank must be between 3 and 5, got {rank}")

        # rank - 2 spatial dims; the two remaining dims are presumably the
        # batch and channel dims — confirm against the driver documentation.
        cdef int n_spatial = rank - 2
        if len(pixel_box_lower_corner) != n_spatial:
            raise ValueError(
                f"pixel_box_lower_corner must have {n_spatial} elements "
                f"(rank - 2), got {len(pixel_box_lower_corner)}")
        if len(pixel_box_upper_corner) != n_spatial:
            raise ValueError(
                f"pixel_box_upper_corner must have {n_spatial} elements "
                f"(rank - 2), got {len(pixel_box_upper_corner)}")

        element_strides = _validate_element_strides(element_strides, rank)

        cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
        byte_strides = _compute_byte_strides(shape, view.strides, elem_size)

        # Reverse all dimension arrays for column-major convention
        cdef uint64_t[5] c_global_dim
        cdef uint64_t[4] c_global_strides
        cdef uint32_t[5] c_element_strides
        cdef int[3] c_pixel_box_lower  # max 3 spatial dims (rank 5 - 2)
        cdef int[3] c_pixel_box_upper
        cdef int i_c

        # Zero-initialize the corner arrays; only the first n_spatial entries
        # are overwritten below.
        for i_c in range(3):
            c_pixel_box_lower[i_c] = 0
            c_pixel_box_upper[i_c] = 0

        for i_c in range(rank):
            c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
            c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]

        # globalStrides: rank-1 entries; the innermost stride is implicit.
        for i_c in range(rank - 1):
            c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]

        # Reverse spatial dimensions for lower/upper corners
        for i_c in range(n_spatial):
            c_pixel_box_lower[i_c] = <int>pixel_box_lower_corner[n_spatial - 1 - i_c]
            c_pixel_box_upper[i_c] = <int>pixel_box_upper_corner[n_spatial - 1 - i_c]

        cdef uint32_t c_rank = <uint32_t>rank
        cdef uint32_t c_channels = <uint32_t>channels_per_pixel
        cdef uint32_t c_pixels = <uint32_t>pixels_per_column
        cdef int c_interleave_int = int(interleave)
        cdef int c_swizzle_int = int(swizzle)
        cdef int c_l2_promotion_int = int(l2_promotion)
        cdef int c_oob_fill_int = int(oob_fill)
        cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
        cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
        cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
        cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int

        with nogil:
            HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2col(
                &desc._tensor_map,
                c_data_type,
                c_rank,
                <void*>global_address,
                c_global_dim,
                c_global_strides,
                c_pixel_box_lower,
                c_pixel_box_upper,
                c_channels,
                c_pixels,
                c_element_strides,
                c_interleave,
                c_swizzle,
                c_l2_promotion,
                c_oob_fill,
            ))

        desc._repr_info = {
            "method": "im2col",
            "rank": rank,
            "data_type": TensorMapDataType(tma_dt),
            "swizzle": swizzle,
        }

        return desc
@classmethod
def _from_im2col_wide(cls, view, pixel_box_lower_corner_width, pixel_box_upper_corner_width,
                      channels_per_pixel, pixels_per_column, *,
                      element_strides=None,
                      data_type=None,
                      interleave=TensorMapInterleave.NONE,
                      mode=TensorMapIm2ColWideMode.W,
                      swizzle=TensorMapSwizzle.SWIZZLE_128B,
                      l2_promotion=TensorMapL2Promotion.NONE,
                      oob_fill=TensorMapOOBFill.NONE):
    """Create an im2col-wide TMA descriptor from a validated view.

    Im2col-wide layout loads elements exclusively along the W (width)
    dimension. This variant is supported on compute capability 10.0+
    (Blackwell and later).

    Parameters
    ----------
    view : StridedMemoryView
        A device-accessible view with a 16-byte-aligned pointer.
    pixel_box_lower_corner_width : int
        Lower corner of the pixel bounding box along the W dimension.
    pixel_box_upper_corner_width : int
        Upper corner of the pixel bounding box along the W dimension.
    channels_per_pixel : int
        Number of channels per pixel.
    pixels_per_column : int
        Number of pixels per column.
    element_strides : tuple of int, optional
        Per-dimension element traversal strides. Default is all 1s.
    data_type : dtype-like or TensorMapDataType, optional
        Explicit dtype override. If ``None``, inferred from the tensor's
        dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
        accepted for compatibility.
    interleave : TensorMapInterleave
        Interleave layout. Default ``NONE``.
    mode : TensorMapIm2ColWideMode
        Im2col wide mode. Default ``W``.
    swizzle : TensorMapSwizzle
        Swizzle mode. Default ``SWIZZLE_128B``.
    l2_promotion : TensorMapL2Promotion
        L2 promotion mode. Default ``NONE``.
    oob_fill : TensorMapOOBFill
        Out-of-bounds fill mode. Default ``NONE``.

    Returns
    -------
    TensorMapDescriptor

    Raises
    ------
    ValueError
        If the tensor rank is outside [3, 5], the pointer is not
        16-byte aligned, or other constraints are violated.
    """
    # Compile-time switch: on pre-13 builds the entire encoding body is
    # compiled out and the method unconditionally raises at runtime
    # (cuTensorMapEncodeIm2colWide is presumably unavailable in older
    # driver headers — the ELSE branch would not compile there).
    IF CUDA_CORE_BUILD_MAJOR < 13:
        raise RuntimeError(
            "TensorMapDescriptor._from_im2col_wide requires a CUDA 13+ build")
    ELSE:
        cdef TensorMapDescriptor desc = cls.__new__(cls)

        _validate_tensor_map_view(view)
        # Keep both the exporting object and the view alive on the
        # descriptor: the encoded CUtensorMap embeds the raw device
        # pointer, so dropping these refs could leave it dangling.
        desc._source_ref = view.exporting_obj
        desc._view_ref = view
        desc._context = _get_current_context_ptr()
        desc._device_id = _get_current_device_id()
        _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_im2col_wide")

        # Resolve the TMA data type (explicit override or inferred from
        # the view's dtype), then convert to the driver enum.
        tma_dt = _resolve_data_type(view, data_type)
        cdef int c_data_type_int = tma_dt
        cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int

        cdef intptr_t global_address = view.ptr
        shape = view.shape

        # Driver constraint for im2col-wide: rank must be in [3, 5].
        cdef int rank = len(shape)
        if rank < 3 or rank > 5:
            raise ValueError(
                f"Im2col-wide tensor rank must be between 3 and 5, got {rank}")

        element_strides = _validate_element_strides(element_strides, rank)

        # Strides are passed to the driver in bytes; derive them from the
        # view's (element) strides and the resolved element size.
        cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
        byte_strides = _compute_byte_strides(shape, view.strides, elem_size)

        # Reverse all dimension arrays for column-major convention.
        # Arrays are sized for the maximum rank (5); only the first
        # `rank` (or `rank - 1` for strides) entries are filled and read.
        cdef uint64_t[5] c_global_dim
        cdef uint64_t[4] c_global_strides
        cdef uint32_t[5] c_element_strides
        cdef int i_c

        for i_c in range(rank):
            c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
            c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]

        # The driver takes only rank-1 strides: the innermost dimension's
        # stride is implied, so byte_strides[rank-1] is intentionally skipped.
        for i_c in range(rank - 1):
            c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]

        # Narrow Python ints / IntEnums to the exact C types the driver
        # signature expects (two-step via int to validate convertibility).
        cdef uint32_t c_rank = <uint32_t>rank
        cdef int c_lower_w = <int>pixel_box_lower_corner_width
        cdef int c_upper_w = <int>pixel_box_upper_corner_width
        cdef uint32_t c_channels = <uint32_t>channels_per_pixel
        cdef uint32_t c_pixels = <uint32_t>pixels_per_column
        cdef int c_interleave_int = int(interleave)
        cdef int c_mode_int = int(mode)
        cdef int c_swizzle_int = int(swizzle)
        cdef int c_l2_promotion_int = int(l2_promotion)
        cdef int c_oob_fill_int = int(oob_fill)
        cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
        cdef cydriver.CUtensorMapIm2ColWideMode c_mode = <cydriver.CUtensorMapIm2ColWideMode>c_mode_int
        cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
        cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
        cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int

        # All arguments are now plain C values, so the GIL can be released
        # for the driver call; HANDLE_RETURN raises on a non-success code.
        with nogil:
            HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2colWide(
                &desc._tensor_map,
                c_data_type,
                c_rank,
                <void*>global_address,
                c_global_dim,
                c_global_strides,
                c_lower_w,
                c_upper_w,
                c_channels,
                c_pixels,
                c_element_strides,
                c_interleave,
                c_mode,
                c_swizzle,
                c_l2_promotion,
                c_oob_fill,
            ))

        # Cache a small summary for __repr__; stored only on success.
        desc._repr_info = {
            "method": "im2col_wide",
            "rank": rank,
            "data_type": TensorMapDataType(tma_dt),
            "swizzle": swizzle,
        }

        return desc
def replace_address(self, tensor):
    """Replace the global memory address in this tensor map descriptor.

    This is useful when the tensor data has been reallocated but the
    shape, strides, and other parameters remain the same.

    Parameters
    ----------
    tensor : object
        Any object supporting DLPack or ``__cuda_array_interface__``,
        or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
        device-accessible memory with a 16-byte-aligned pointer.
    """
    # NOTE(review): presumably verifies the current CUDA context matches
    # the one captured at construction (_context) — confirm in its impl.
    self._check_context_compat()
    # Normalize the input to a validated StridedMemoryView; the new
    # tensor must reside on the same device the descriptor was built for.
    view = _get_validated_view(tensor)
    _require_view_device(view, self._device_id, "replace_address")

    cdef intptr_t global_address = view.ptr

    # Pointer is a plain C value, so release the GIL for the driver call;
    # HANDLE_RETURN raises if the driver reports an error.
    with nogil:
        HANDLE_RETURN(cydriver.cuTensorMapReplaceAddress(
            &self._tensor_map,
            <void*>global_address,
        ))

    # Update the source reference only after the driver call succeeds,
    # so we don't drop the old tensor (risking a dangling pointer in the
    # CUtensorMap struct) if the call fails.
    self._source_ref = view.exporting_obj
    self._view_ref = view
def __repr__(self):
    """Summarize the descriptor from the construction info cached in
    ``_repr_info`` (creation method, rank, dtype, swizzle); falls back
    to a bare ``TensorMapDescriptor()`` when no info was recorded.
    """
    info = self._repr_info
    if info is None:
        return "TensorMapDescriptor()"
    # Table of (key, formatter) pairs, rendered in this fixed order for
    # whichever keys the constructor chose to record.
    renderers = (
        ("method", lambda v: v),
        ("rank", lambda v: f"rank={v}"),
        ("data_type", lambda v: f"dtype={v.name}"),
        ("swizzle", lambda v: f"swizzle={v.name}"),
    )
    pieces = [render(info[key]) for key, render in renderers if key in info]
    return f"TensorMapDescriptor({', '.join(pieces)})"