Coverage for cuda/core/_tensor_map.pyx: 29.60%

527 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-13 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5from libc.stdint cimport intptr_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t 

6from libc.stddef cimport size_t 

7from cuda.bindings cimport cydriver 

8from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

9from cuda.core._dlpack cimport kDLInt, kDLUInt, kDLFloat, kDLBfloat, _kDLCUDA 

10  

11import enum 

12from dataclasses import dataclass 

13from typing import TYPE_CHECKING 

14  

15import numpy 

16  

17from cuda.core._memoryview import StridedMemoryView 

18from cuda.core._utils.cuda_utils import check_or_create_options 

19  

20if TYPE_CHECKING: 

21 from cuda.core._device import Device 

22  

# CCCL/libcu++-backed helper that encodes a tiled CUtensorMap from
# DLPack-style tensor metadata. As used by TensorMapDescriptor._from_tiled in
# this file:
#   * a return value of 0 is treated as success, with the descriptor written
#     to ``out_tensor_map``;
#   * any other return value is treated as failure, and a NUL-terminated
#     message is read back from ``err`` (capacity ``err_cap``); a message
#     containing "not available at build time" makes the caller fall back to
#     the driver API (cuTensorMapEncodeTiled).
# NOTE(review): the authoritative contract lives in _cpp/tensor_map_cccl.h,
# which is not visible here — confirm details against that header.
cdef extern from "_cpp/tensor_map_cccl.h":
    int cuda_core_cccl_make_tma_descriptor_tiled(
        void* out_tensor_map,
        void* data,
        int device_type,              # DLPack device type of ``data``
        int device_id,                # DLPack device id of ``data``
        int ndim,
        const int64_t* shape,         # ndim extents, row-major (Python) order
        const int64_t* strides,       # ndim element strides, or NULL when the view reports none
        uint8_t dtype_code,           # DLPack dtype code / bits / lanes
        uint8_t dtype_bits,
        uint16_t dtype_lanes,
        const int* box_sizes,         # ndim box extents, row-major order
        const int* elem_strides,      # ndim element strides, or NULL when the caller supplied none
        int interleave_layout,
        int swizzle,
        int l2_fetch_size,            # the caller passes the TensorMapL2Promotion value here
        int oob_fill,
        char* err,
        size_t err_cap) nogil

43  

44  

45try: 

46 from ml_dtypes import bfloat16 as ml_bfloat16 

47except ImportError: 

48 ml_bfloat16 = None 

49  

50  

class TensorMapDataType(enum.IntEnum):
    """Data types for tensor map descriptors.

    These correspond to the ``CUtensorMapDataType`` driver enum values.

    A member may be passed as ``data_type`` to select an encoding explicitly.
    ``FLOAT32_FTZ``, ``TFLOAT32`` and ``TFLOAT32_FTZ`` have no NumPy dtype
    mapping in this module, so passing the enum member is the only way to
    select them. ``BFLOAT16`` is additionally reachable from the
    ``ml_dtypes.bfloat16`` dtype when ``ml_dtypes`` is importable.
    """
    # Values come straight from the cimported driver enum so they always
    # match the CUDA headers the extension is compiled against.
    UINT8 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8
    UINT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16
    UINT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32
    INT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32
    UINT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64
    INT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64
    FLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16
    FLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32
    FLOAT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64
    BFLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16
    FLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ
    TFLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32
    TFLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ

69  

70  

class TensorMapInterleave(enum.IntEnum):
    """Interleave layout for tensor map descriptors.

    These correspond to the ``CUtensorMapInterleave`` driver enum values.
    ``NONE`` is the default everywhere an interleave argument is accepted.
    TensorMapDescriptorOptions requires a member of this enum; plain integers
    are rejected with ``TypeError``.
    """
    NONE = cydriver.CU_TENSOR_MAP_INTERLEAVE_NONE
    INTERLEAVE_16B = cydriver.CU_TENSOR_MAP_INTERLEAVE_16B
    INTERLEAVE_32B = cydriver.CU_TENSOR_MAP_INTERLEAVE_32B

80  

class TensorMapSwizzle(enum.IntEnum):
    """Swizzle mode for tensor map descriptors.

    These correspond to the ``CUtensorMapSwizzle`` driver enum values.
    ``NONE`` is the default for tiled and im2col descriptors.
    TensorMapDescriptorOptions requires a member of this enum; plain integers
    are rejected with ``TypeError``.
    """
    NONE = cydriver.CU_TENSOR_MAP_SWIZZLE_NONE
    SWIZZLE_32B = cydriver.CU_TENSOR_MAP_SWIZZLE_32B
    SWIZZLE_64B = cydriver.CU_TENSOR_MAP_SWIZZLE_64B
    SWIZZLE_128B = cydriver.CU_TENSOR_MAP_SWIZZLE_128B

90  

91  

class TensorMapL2Promotion(enum.IntEnum):
    """L2 promotion mode for tensor map descriptors.

    These correspond to the ``CUtensorMapL2promotion`` driver enum values.
    ``NONE`` is the default everywhere an l2_promotion argument is accepted.
    TensorMapDescriptorOptions requires a member of this enum; plain integers
    are rejected with ``TypeError``.
    """
    NONE = cydriver.CU_TENSOR_MAP_L2_PROMOTION_NONE
    L2_64B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_64B
    L2_128B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_128B
    L2_256B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_256B

101  

102  

class TensorMapOOBFill(enum.IntEnum):
    """Out-of-bounds fill mode for tensor map descriptors.

    These correspond to the ``CUtensorMapFloatOOBfill`` driver enum values.
    ``NONE`` is the default everywhere an oob_fill argument is accepted.
    TensorMapDescriptorOptions requires a member of this enum; plain integers
    are rejected with ``TypeError``.
    """
    NONE = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    NAN_REQUEST_ZERO_FMA = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA

110  

111  

# TensorMapIm2ColWideMode is defined under a Cython compile-time IF so that the
# name exists in every build. CUDA_CORE_BUILD_MAJOR is a compile-time constant
# supplied by the build (presumably the CUDA major version being compiled
# against — confirm in the build configuration).
# NOTE(review): Cython 3 deprecates the compile-time IF/ELSE/DEF statements.
IF CUDA_CORE_BUILD_MAJOR >= 13:
    class TensorMapIm2ColWideMode(enum.IntEnum):
        """Im2col wide mode for tensor map descriptors.

        These correspond to the ``CUtensorMapIm2ColWideMode`` driver enum values.
        Supported on compute capability 10.0+.
        """
        W = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W
        W128 = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W128
ELSE:
    class TensorMapIm2ColWideMode(enum.IntEnum):
        """Im2col wide mode for tensor map descriptors.

        This enum is always defined for API stability, but the
        :meth:`TensorMapDescriptor._from_im2col_wide` factory requires a CUDA 13+
        build and will raise otherwise.
        """
        # Hard-coded placeholders: this branch does not reference the driver
        # enum. NOTE(review): presumably chosen to mirror the CUDA 13 driver
        # numbering — confirm.
        W = 0
        W128 = 1

131  

132  

# Plain Python-int copies of the driver's CUtensorMapDataType values. They are
# the keys of _NUMPY_DTYPE_TO_TMA's values / _TMA_DATA_TYPE_SIZE's keys and the
# codes returned by _resolve_data_type.
_TMA_DT_UINT8 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8)
_TMA_DT_UINT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16)
_TMA_DT_UINT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32)
_TMA_DT_INT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32)
_TMA_DT_UINT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64)
_TMA_DT_INT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64)
_TMA_DT_FLOAT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16)
_TMA_DT_FLOAT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32)
_TMA_DT_FLOAT64 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64)
_TMA_DT_BFLOAT16 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16)
_TMA_DT_FLOAT32_FTZ = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ)
_TMA_DT_TFLOAT32 = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32)
_TMA_DT_TFLOAT32_FTZ = int(cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ)

146  

147  

148def _normalize_tensor_map_data_type(data_type): 

149 if data_type is None or isinstance(data_type, TensorMapDataType): 

150 return data_type 

151 try: 

152 return numpy.dtype(data_type) 

153 except TypeError as e: 

154 raise TypeError( 

155 "data_type must be a TensorMapDataType or a numpy/ml_dtypes dtype, " 

156 f"got {type(data_type)}") from e 

157  

158  

159def _normalize_tensor_map_sequence(name, values): 

160 try: 

161 values = tuple(values) 

162 except TypeError as e: 

163 raise TypeError(f"{name} must be a tuple of ints, got {type(values)}") from e 

164 for i, value in enumerate(values): 

165 if not isinstance(value, int): 

166 raise TypeError(f"{name}[{i}] must be an int, got {type(value)}") 

167 return values 

168  

169  

170def _require_tensor_map_enum(name, value, enum_type): 

171 if not isinstance(value, enum_type): 

172 raise TypeError(f"{name} must be a {enum_type.__name__}, got {type(value)}") 

173 return value 

174  

175  

@dataclass
class TensorMapDescriptorOptions:
    """Bundled settings for :meth:`cuda.core.StridedMemoryView.as_tensor_map`.

    Every field is checked and normalized in ``__post_init__``: the sequence
    fields become tuples of ints, ``data_type`` is canonicalized, and each
    enum-valued field must be a member of its own enum type (``TypeError``
    otherwise).

    Attributes
    ----------
    box_dim : tuple[int, ...]
        Tile extent for each tensor dimension, in elements.
    element_strides : tuple[int, ...], optional
        Per-dimension traversal stride, in elements. ``None`` (the default)
        means no explicit strides.
    data_type : object, optional
        Explicit dtype override. NumPy or ``ml_dtypes`` dtype objects are
        preferred; :class:`TensorMapDataType` members are still accepted.
    interleave : TensorMapInterleave, optional
        Interleave layout (default ``NONE``).
    swizzle : TensorMapSwizzle, optional
        Swizzle mode (default ``NONE``).
    l2_promotion : TensorMapL2Promotion, optional
        L2 promotion mode (default ``NONE``).
    oob_fill : TensorMapOOBFill, optional
        Out-of-bounds fill mode (default ``NONE``).
    """

    box_dim: tuple[int, ...]
    element_strides: tuple[int, ...] | None = None
    data_type: object = None
    interleave: TensorMapInterleave = TensorMapInterleave.NONE
    swizzle: TensorMapSwizzle = TensorMapSwizzle.NONE
    l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE
    oob_fill: TensorMapOOBFill = TensorMapOOBFill.NONE

    def __post_init__(self) -> None:
        # Sequence fields: coerce to tuples of ints; an omitted
        # element_strides stays None.
        self.box_dim = _normalize_tensor_map_sequence("box_dim", self.box_dim)
        strides = self.element_strides
        if strides is not None:
            self.element_strides = _normalize_tensor_map_sequence("element_strides", strides)

        self.data_type = _normalize_tensor_map_data_type(self.data_type)

        # Enum fields, validated in declaration order.
        for attr_name, enum_type in (
            ("interleave", TensorMapInterleave),
            ("swizzle", TensorMapSwizzle),
            ("l2_promotion", TensorMapL2Promotion),
            ("oob_fill", TensorMapOOBFill),
        ):
            checked = _require_tensor_map_enum(attr_name, getattr(self, attr_name), enum_type)
            setattr(self, attr_name, checked)

216  

217  

218def _coerce_tensor_map_descriptor_options( 

219 box_dim, 

220 options, 

221 *, 

222 element_strides, 

223 data_type, 

224 interleave, 

225 swizzle, 

226 l2_promotion, 

227 oob_fill, 

228): 

229 if options is not None: 

230 if ( 

231 box_dim is not None 

232 or element_strides is not None 

233 or data_type is not None 

234 or interleave != TensorMapInterleave.NONE 

235 or swizzle != TensorMapSwizzle.NONE 

236 or l2_promotion != TensorMapL2Promotion.NONE 

237 or oob_fill != TensorMapOOBFill.NONE 

238 ): 

239 raise TypeError( 

240 "Specify either options or the individual tensor map arguments, not both") 

241 return check_or_create_options( 

242 TensorMapDescriptorOptions, 

243 options, 

244 "Tensor map descriptor options", 

245 ) 

246  

247 if box_dim is None: 

248 raise TypeError("box_dim is required unless options is provided") 

249  

250 return TensorMapDescriptorOptions( 

251 box_dim=box_dim, 

252 element_strides=element_strides, 

253 data_type=data_type, 

254 interleave=interleave, 

255 swizzle=swizzle, 

256 l2_promotion=l2_promotion, 

257 oob_fill=oob_fill, 

258 ) 

259  

260  

# Mapping from NumPy dtype to TMA data type code. _resolve_data_type uses it
# both for explicit data_type overrides and for inference from view.dtype, and
# lists its keys (in insertion order) in the "supported dtypes" error message.
# Only exact dtype matches are found (a lookup by numpy.dtype equality).
# There is no 8/16-bit signed entry because TensorMapDataType has no such member.
_NUMPY_DTYPE_TO_TMA = {
    numpy.dtype(numpy.uint8): _TMA_DT_UINT8,
    numpy.dtype(numpy.uint16): _TMA_DT_UINT16,
    numpy.dtype(numpy.uint32): _TMA_DT_UINT32,
    numpy.dtype(numpy.int32): _TMA_DT_INT32,
    numpy.dtype(numpy.uint64): _TMA_DT_UINT64,
    numpy.dtype(numpy.int64): _TMA_DT_INT64,
    numpy.dtype(numpy.float16): _TMA_DT_FLOAT16,
    numpy.dtype(numpy.float32): _TMA_DT_FLOAT32,
    numpy.dtype(numpy.float64): _TMA_DT_FLOAT64,
}

# bfloat16 is only reachable from a dtype when the optional ml_dtypes package
# imported successfully; otherwise pass TensorMapDataType.BFLOAT16 explicitly.
if ml_bfloat16 is not None:
    _NUMPY_DTYPE_TO_TMA[numpy.dtype(ml_bfloat16)] = _TMA_DT_BFLOAT16

276  

277  

# Mapping from TMA data type code to element size in bytes. The driver-API
# encoding paths use it to convert element strides into byte strides via
# _compute_byte_strides.
_TMA_DATA_TYPE_SIZE = {
    _TMA_DT_UINT8: 1,
    _TMA_DT_UINT16: 2,
    _TMA_DT_UINT32: 4,
    _TMA_DT_INT32: 4,
    _TMA_DT_UINT64: 8,
    _TMA_DT_INT64: 8,
    _TMA_DT_FLOAT16: 2,
    _TMA_DT_FLOAT32: 4,
    _TMA_DT_FLOAT64: 8,
    _TMA_DT_BFLOAT16: 2,
    _TMA_DT_FLOAT32_FTZ: 4,
    _TMA_DT_TFLOAT32: 4,
    _TMA_DT_TFLOAT32_FTZ: 4,
}

294  

295  

296def _resolve_data_type(view, data_type): 

297 """Resolve the TMA data type from an explicit value or the view's dtype.""" 

298  

299 if data_type is not None: 

300 if isinstance(data_type, TensorMapDataType): 

301 return int(data_type) 

302 dt = _normalize_tensor_map_data_type(data_type) 

303 tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt) 

304 if tma_dt is None: 

305 raise ValueError( 

306 f"Unsupported dtype {dt} for TMA; " 

307 f"supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}.") 

308 return tma_dt 

309  

310 dt = view.dtype 

311 if dt is None: 

312 raise ValueError( 

313 "Cannot infer TMA data type from the tensor; " 

314 "please specify data_type explicitly") 

315  

316 tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt) 

317 if tma_dt is None: 

318 raise ValueError( 

319 f"Unsupported dtype {dt} for TMA; " 

320 f"supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}. " 

321 "You may also specify data_type explicitly.") 

322  

323 return tma_dt 

324  

325  

cdef inline bint _tma_dtype_to_dlpack(
    int tma_dt,
    uint8_t* out_code,
    uint8_t* out_bits,
    uint16_t* out_lanes,
) noexcept:
    """Translate a TMA data type code into a scalar DLPack dtype triple.

    On success writes the DLPack type code, the bit width and a lane count
    of 1 through the output pointers and returns True. For codes without a
    mapping here (e.g. FLOAT32_FTZ, TFLOAT32, TFLOAT32_FTZ) returns False
    and leaves the outputs untouched.
    """
    # Compare against the cimported driver enum constants instead of the
    # module-level ``_TMA_DT_*`` Python ints (which hold the same values).
    # That keeps this ``noexcept`` function pure C: no global dict lookups or
    # Python comparisons whose errors could only be printed and swallowed.
    cdef uint8_t code
    cdef uint8_t bits
    if tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8:
        code = <uint8_t>kDLUInt
        bits = 8
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16:
        code = <uint8_t>kDLUInt
        bits = 16
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32:
        code = <uint8_t>kDLUInt
        bits = 32
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64:
        code = <uint8_t>kDLUInt
        bits = 64
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32:
        code = <uint8_t>kDLInt
        bits = 32
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64:
        code = <uint8_t>kDLInt
        bits = 64
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16:
        code = <uint8_t>kDLFloat
        bits = 16
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32:
        code = <uint8_t>kDLFloat
        bits = 32
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64:
        code = <uint8_t>kDLFloat
        bits = 64
    elif tma_dt == <int>cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16:
        code = <uint8_t>kDLBfloat
        bits = 16
    else:
        return False

    out_code[0] = code
    out_bits[0] = bits
    out_lanes[0] = <uint16_t>1
    return True

383  

384  

cdef inline int _validate_tensor_map_view(view) except -1:
    """Check that *view* can back a TMA descriptor.

    The view must be device-accessible and its data pointer a multiple of
    16 bytes. Returns 0 when both hold; raises ValueError otherwise.
    """
    if view.is_device_accessible:
        address = view.ptr
        if address % 16 == 0:
            return 0
        raise ValueError(
            f"Global memory address must be 16-byte aligned, "
            f"got address 0x{address:x}")
    raise ValueError("The tensor must be device-accessible")

394  

395  

def _get_validated_view(tensor):
    """Return a StridedMemoryView of *tensor* that is usable for TMA.

    *tensor* may already be a StridedMemoryView; otherwise one is built with
    ``StridedMemoryView.from_any_interface``. The view is then checked by
    ``_validate_tensor_map_view`` (device-accessible, 16-byte-aligned
    pointer), which raises ValueError on failure.
    """
    if not isinstance(tensor, StridedMemoryView):
        # stream_ptr=-1: no stream synchronization needed because descriptor
        # creation only reads tensor metadata, it does not move data.
        tensor = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1)
    _validate_tensor_map_view(tensor)
    return tensor

406  

407  

def _require_view_device(view, expected_device_id, operation):
    """Raise ValueError if a device-resident tensor lives on another device.

    Only tensors whose DLPack device type is ``kDLCUDA`` are checked. DLPack
    reports host/managed CUDA memory as ``kDLCUDAHost`` / ``kDLCUDAManaged``
    with ``device_id=0`` regardless of the current device, so those tensors
    always pass.

    Parameters
    ----------
    view : object
        Anything exposing ``__dlpack_device__()``.
    expected_device_id : int
        Device ordinal a ``kDLCUDA`` tensor must be on.
    operation : str
        Name of the calling operation, used in the error message.
    """
    dev_type, dev_id = view.__dlpack_device__()
    if dev_type != _kDLCUDA or dev_id == expected_device_id:
        return
    raise ValueError(
        f"{operation} expects tensor on device {expected_device_id}, got {dev_id}")

cdef inline intptr_t _get_current_context_ptr() except? 0:
    """Return the calling thread's current CUDA context as an integer.

    Raises RuntimeError when no context is current; errors from
    ``cuCtxGetCurrent`` propagate through ``HANDLE_RETURN``.
    """
    cdef cydriver.CUcontext current_ctx
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetCurrent(&current_ctx))
    if current_ctx is NULL:
        raise RuntimeError("TensorMapDescriptor requires an active CUDA context")
    return <intptr_t>current_ctx

426  

427  

cdef inline int _get_current_device_id() except -1:
    """Return the device of the current CUDA context as an int.

    Errors from ``cuCtxGetDevice`` propagate through ``HANDLE_RETURN``.
    """
    cdef cydriver.CUdevice ctx_device
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetDevice(&ctx_device))
    cdef int device_ordinal = <int>ctx_device
    return device_ordinal

433  

434def _compute_byte_strides(shape, strides, elem_size): 

435 """Compute byte strides from element strides or C-contiguous fallback. 

436  

437 Returns a tuple of byte strides in row-major order. 

438 """ 

439 if strides is not None: 

440 return tuple(s * elem_size for s in strides) 

441  

442 # C-contiguous: compute byte strides from shape, innermost first 

443 rank = len(shape) 

444 byte_strides = [] 

445 stride = elem_size 

446 for i in range(rank - 1, -1, -1): 

447 byte_strides.append(stride) 

448 stride *= shape[i] 

449 byte_strides.reverse() 

450 return tuple(byte_strides) 

451  

452  

453def _validate_element_strides(element_strides, rank): 

454 """Validate or default element_strides to all-ones.""" 

455 if element_strides is not None: 

456 if len(element_strides) != rank: 

457 raise ValueError( 

458 f"element_strides must have {rank} elements, got {len(element_strides)}") 

459 return element_strides 

460 return (1,) * rank 

461  

462  

463cdef class TensorMapDescriptor: 

464 """Describes a TMA (Tensor Memory Accelerator) tensor map for Hopper+ GPUs. 

465  

466 A ``TensorMapDescriptor`` wraps the opaque 128-byte ``CUtensorMap`` struct 

467 used by the hardware TMA unit for efficient bulk data movement between 

468 global and shared memory. 

469  

470 Public tiled descriptors are created via 

471 :meth:`cuda.core.StridedMemoryView.as_tensor_map`. Specialized 

472 ``_from_*`` helpers remain private while this API surface settles, and 

473 descriptors can be passed directly to :func:`~cuda.core.launch` as a 

474 kernel argument. 

475 """ 

476  

477 def __init__(self): 

478 raise RuntimeError( 1f

479 "TensorMapDescriptor cannot be instantiated directly. " 

480 "Use StridedMemoryView.as_tensor_map() instead.") 

481  

482 cdef void* _get_data_ptr(self): 

483 return <void*>&self._tensor_map 

484  

485 cdef int _check_context_compat(self) except -1: 

486 cdef cydriver.CUcontext current_ctx 

487 cdef cydriver.CUdevice current_dev 

488 cdef int current_dev_id 

489 if self._context == 0 and self._device_id < 0: 

490 return 0 

491 with nogil: 

492 HANDLE_RETURN(cydriver.cuCtxGetCurrent(&current_ctx)) 

493 if current_ctx == NULL: 

494 raise RuntimeError("TensorMapDescriptor requires an active CUDA context") 

495 if self._context != 0 and <intptr_t>current_ctx != self._context: 

496 raise RuntimeError( 

497 "TensorMapDescriptor was created in a different CUDA context") 

498 with nogil: 

499 HANDLE_RETURN(cydriver.cuCtxGetDevice(&current_dev)) 

500 current_dev_id = <int>current_dev 

501 if self._device_id >= 0 and current_dev_id != self._device_id: 

502 raise RuntimeError( 

503 f"TensorMapDescriptor belongs to device {self._device_id}, " 

504 f"but current device is {current_dev_id}") 

505 return 0 

506  

507 @property 

508 def device(self) -> Device | None: 

509 """Return the :obj:`~cuda.core.Device` associated with this descriptor.""" 

510 if self._device_id >= 0: 

511 from cuda.core._device import Device 

512 return Device(self._device_id) 

513 return None 

514  

    @classmethod
    def _from_tiled(cls, view, box_dim=None, *,
                    options=None,
                    element_strides=None,
                    data_type=None,
                    interleave=TensorMapInterleave.NONE,
                    swizzle=TensorMapSwizzle.NONE,
                    l2_promotion=TensorMapL2Promotion.NONE,
                    oob_fill=TensorMapOOBFill.NONE):
        """Create a tiled TMA descriptor from a validated view.

        Parameters
        ----------
        view : StridedMemoryView
            A device-accessible view with a 16-byte-aligned pointer.
        box_dim : tuple of int, optional
            The size of each tile dimension (in elements). Must have the
            same rank as the tensor and each value must be in [1, 256].
            Specified in the same (row-major) order as the tensor shape.
            Required unless ``options`` is provided.
        options : TensorMapDescriptorOptions or mapping, optional
            Bundled tiled-descriptor options. When provided, do not also pass
            ``box_dim`` or the individual option kwargs.
        element_strides : tuple of int, optional
            Per-dimension element traversal strides. Default is all 1s.
            Specified in the same (row-major) order as the tensor shape.
        data_type : dtype-like or TensorMapDataType, optional
            Explicit dtype override. If ``None``, inferred from the tensor's
            dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
            accepted for compatibility.
        interleave : TensorMapInterleave
            Interleave layout. Default ``NONE``.
        swizzle : TensorMapSwizzle
            Swizzle mode. Default ``NONE``.
        l2_promotion : TensorMapL2Promotion
            L2 promotion mode. Default ``NONE``.
        oob_fill : TensorMapOOBFill
            Out-of-bounds fill mode. Default ``NONE``.

        Returns
        -------
        TensorMapDescriptor

        Raises
        ------
        ValueError
            If the tensor rank is outside [1, 5], the pointer is not
            16-byte aligned, or dimension/stride constraints are violated.
            Also raised for a device mismatch, an unsupported dtype, or when
            the CCCL helper fails for any reason other than being unavailable
            at build time.
        TypeError
            If ``options`` is combined with individual option arguments, or
            neither ``options`` nor ``box_dim`` is given, or an option has
            the wrong type.
        RuntimeError
            If no CUDA context is current.

        Notes
        -----
        For data types with a DLPack equivalent, encoding is first attempted
        through ``cuda_core_cccl_make_tma_descriptor_tiled``. If that helper
        reports it was not available at build time, or the data type has no
        DLPack mapping (FLOAT32_FTZ / TFLOAT32 / TFLOAT32_FTZ), the descriptor
        is encoded directly with ``cuTensorMapEncodeTiled``.
        """
        cdef TensorMapDescriptor desc = cls.__new__(cls)

        # Accept either a bundled ``options`` object or the individual keyword
        # arguments (validated/normalized by TensorMapDescriptorOptions), then
        # unpack the result into locals for the rest of the method.
        opts = _coerce_tensor_map_descriptor_options(
            box_dim,
            options,
            element_strides=element_strides,
            data_type=data_type,
            interleave=interleave,
            swizzle=swizzle,
            l2_promotion=l2_promotion,
            oob_fill=oob_fill,
        )
        box_dim = opts.box_dim
        element_strides = opts.element_strides
        data_type = opts.data_type
        interleave = opts.interleave
        swizzle = opts.swizzle
        l2_promotion = opts.l2_promotion
        oob_fill = opts.oob_fill

        _validate_tensor_map_view(view)
        # Keep both the original tensor object and the validated view alive.
        # For DLPack exporters, the view may hold the owning capsule whose
        # deleter can free the backing allocation when released.
        desc._source_ref = view.exporting_obj
        desc._view_ref = view
        # Record the creating context/device; _check_context_compat compares
        # the current context/device against these.
        desc._context = _get_current_context_ptr()
        desc._device_id = _get_current_device_id()
        _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_tiled")

        tma_dt = _resolve_data_type(view, data_type)
        cdef int c_data_type_int = tma_dt
        cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int

        cdef intptr_t global_address = view.ptr
        shape = view.shape

        cdef int rank = len(shape)
        if rank < 1 or rank > 5:
            raise ValueError(
                f"Tensor rank must be between 1 and 5, got {rank}")

        if len(box_dim) != rank:
            raise ValueError(
                f"box_dim must have {rank} elements (same as tensor rank), "
                f"got {len(box_dim)}")

        for i, bd in enumerate(box_dim):
            if bd < 1 or bd > 256:
                raise ValueError(
                    f"box_dim[{i}] must be in [1, 256], got {bd}")

        # Remember whether the caller supplied element strides *before*
        # defaulting them: the CCCL path passes NULL when they were omitted.
        cdef bint elem_strides_provided = element_strides is not None
        element_strides = _validate_element_strides(element_strides, rank)

        # Reuse CCCL/libcu++'s DLPack -> CUtensorMap conversion when possible.
        # This avoids maintaining a second, independent validation/encoding implementation.
        cdef uint8_t dl_code
        cdef uint8_t dl_bits
        cdef uint16_t dl_lanes
        cdef int64_t c_shape[5]
        cdef int64_t c_strides[5]
        cdef int c_box_sizes[5]
        cdef int c_elem_strides[5]
        cdef const int64_t* c_strides_ptr
        cdef const int* c_elem_strides_ptr
        cdef char errbuf[512]
        cdef int i_cccl
        cdef int device_type
        cdef int c_device_id
        cdef int dl_device_type
        cdef int dl_device_id
        cdef int c_cccl_interleave_int
        cdef int c_cccl_swizzle_int
        cdef int c_cccl_l2_promotion_int
        cdef int c_cccl_oob_fill_int
        cdef int rc
        # Only data types with a DLPack equivalent can take the CCCL path;
        # the rest skip straight to the driver encoding below.
        if _tma_dtype_to_dlpack(tma_dt, &dl_code, &dl_bits, &dl_lanes):
            c_strides_ptr = NULL
            c_elem_strides_ptr = NULL
            # Start with an empty C string so a failure that writes no message
            # decodes to "".
            errbuf[0] = 0

            # CCCL takes shape/box/element strides in the same row-major order
            # as the Python shape (no reversal on this path).
            for i_cccl in range(rank):
                c_shape[i_cccl] = <int64_t>shape[i_cccl]
                c_box_sizes[i_cccl] = <int>box_dim[i_cccl]
                if elem_strides_provided:
                    c_elem_strides[i_cccl] = <int>element_strides[i_cccl]

            # view.strides are element strides (the driver path below scales
            # them to bytes); NULL is passed when the view reports none.
            if view.strides is not None:
                for i_cccl in range(rank):
                    c_strides[i_cccl] = <int64_t>view.strides[i_cccl]
                c_strides_ptr = &c_strides[0]

            if elem_strides_provided:
                c_elem_strides_ptr = &c_elem_strides[0]

            dl_device_type, dl_device_id = view.__dlpack_device__()
            device_type = dl_device_type
            c_device_id = dl_device_id
            c_cccl_interleave_int = int(interleave)
            c_cccl_swizzle_int = int(swizzle)
            c_cccl_l2_promotion_int = int(l2_promotion)
            c_cccl_oob_fill_int = int(oob_fill)

            with nogil:
                rc = cuda_core_cccl_make_tma_descriptor_tiled(
                    <void*>&desc._tensor_map,
                    <void*>global_address,
                    device_type,
                    c_device_id,
                    rank,
                    &c_shape[0],
                    c_strides_ptr,
                    dl_code,
                    dl_bits,
                    dl_lanes,
                    &c_box_sizes[0],
                    c_elem_strides_ptr,
                    c_cccl_interleave_int,
                    c_cccl_swizzle_int,
                    c_cccl_l2_promotion_int,
                    c_cccl_oob_fill_int,
                    &errbuf[0],
                    <size_t>sizeof(errbuf),
                )

            if rc == 0:
                desc._repr_info = {
                    "method": "tiled",
                    "rank": rank,
                    "data_type": TensorMapDataType(tma_dt),
                    "swizzle": swizzle,
                }
                return desc

            # Decode the NUL-terminated message written into errbuf.
            msg = errbuf[:].split(b"\0", 1)[0].decode("utf-8", errors="replace")
            # If CCCL isn't available at build time, fall back to the direct
            # driver API path to preserve functionality on older toolchains.
            if "not available at build time" not in msg:
                raise ValueError(f"Failed to build TMA descriptor via CCCL: {msg}")

        # Driver fallback: reached when the data type has no DLPack mapping or
        # the CCCL helper was not built in.
        # NOTE(review): elem_size comes from the resolved TMA type, not
        # view.dtype; an override with a different item size rescales the
        # strides — confirm this is intended.
        cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
        byte_strides = _compute_byte_strides(shape, view.strides, elem_size)

        # Reverse dimensions for column-major cuTensorMap convention
        # Python/DLPack: row-major (dim 0 = outermost)
        # cuTensorMap: column-major (dim 0 = innermost)
        cdef uint64_t[5] c_global_dim
        cdef uint64_t[4] c_global_strides  # rank - 1 elements
        cdef uint32_t[5] c_box_dim
        cdef uint32_t[5] c_element_strides
        cdef int i_c

        for i_c in range(rank):
            # Reverse: Python dim i -> cuTensorMap dim (rank - 1 - i)
            c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
            c_box_dim[i_c] = <uint32_t>box_dim[rank - 1 - i_c]
            c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]

        # globalStrides: rank-1 elements (byte strides for dims 1..N-1 in col-major order)
        # The innermost stride (dim 0) is implicit = element size
        for i_c in range(rank - 1):
            c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]

        cdef uint32_t c_rank = <uint32_t>rank
        cdef int c_interleave_int = int(interleave)
        cdef int c_swizzle_int = int(swizzle)
        cdef int c_l2_promotion_int = int(l2_promotion)
        cdef int c_oob_fill_int = int(oob_fill)
        cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
        cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
        cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
        cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int

        with nogil:
            HANDLE_RETURN(cydriver.cuTensorMapEncodeTiled(
                &desc._tensor_map,
                c_data_type,
                c_rank,
                <void*>global_address,
                c_global_dim,
                c_global_strides,
                c_box_dim,
                c_element_strides,
                c_interleave,
                c_swizzle,
                c_l2_promotion,
                c_oob_fill,
            ))

        desc._repr_info = {
            "method": "tiled",
            "rank": rank,
            "data_type": TensorMapDataType(tma_dt),
            "swizzle": swizzle,
        }

        return desc

762  

    @classmethod
    def _from_im2col(cls, view, pixel_box_lower_corner, pixel_box_upper_corner,
                     channels_per_pixel, pixels_per_column, *,
                     element_strides=None,
                     data_type=None,
                     interleave=TensorMapInterleave.NONE,
                     swizzle=TensorMapSwizzle.NONE,
                     l2_promotion=TensorMapL2Promotion.NONE,
                     oob_fill=TensorMapOOBFill.NONE):
        """Create an im2col TMA descriptor from a validated view.

        Im2col layout is used for convolution-style data access patterns.

        Parameters
        ----------
        view : StridedMemoryView
            A device-accessible view with a 16-byte-aligned pointer.
        pixel_box_lower_corner : tuple of int
            Lower corner of the pixel bounding box for each spatial
            dimension (rank - 2 elements). Specified in row-major order
            matching the tensor's spatial dimensions.
        pixel_box_upper_corner : tuple of int
            Upper corner of the pixel bounding box for each spatial
            dimension (rank - 2 elements). Specified in row-major order
            matching the tensor's spatial dimensions.
        channels_per_pixel : int
            Number of channels per pixel.
        pixels_per_column : int
            Number of pixels per column.
        element_strides : tuple of int, optional
            Per-dimension element traversal strides. Default is all 1s.
            Specified in the same (row-major) order as the tensor shape.
        data_type : dtype-like or TensorMapDataType, optional
            Explicit dtype override. If ``None``, inferred from the tensor's
            dtype. Prefer NumPy or ``ml_dtypes`` dtype objects; the enum is
            accepted for compatibility.
        interleave : TensorMapInterleave
            Interleave layout. Default ``NONE``.
        swizzle : TensorMapSwizzle
            Swizzle mode. Default ``NONE``.
        l2_promotion : TensorMapL2Promotion
            L2 promotion mode. Default ``NONE``.
        oob_fill : TensorMapOOBFill
            Out-of-bounds fill mode. Default ``NONE``.

        Returns
        -------
        TensorMapDescriptor

        Raises
        ------
        ValueError
            If the tensor rank is outside [3, 5], the pointer is not
            16-byte aligned, or other constraints are violated (corner
            lengths, element strides, device mismatch, unsupported dtype).
        RuntimeError
            If no CUDA context is current.

        Notes
        -----
        Always encodes through ``cuTensorMapEncodeIm2col``; driver errors
        propagate through ``HANDLE_RETURN``.
        """
        cdef TensorMapDescriptor desc = cls.__new__(cls)

        _validate_tensor_map_view(view)
        # Keep the exporting object and the view alive for as long as the
        # descriptor exists, as in _from_tiled.
        desc._source_ref = view.exporting_obj
        desc._view_ref = view
        desc._context = _get_current_context_ptr()
        desc._device_id = _get_current_device_id()
        _require_view_device(view, desc._device_id, "TensorMapDescriptor._from_im2col")

        tma_dt = _resolve_data_type(view, data_type)
        cdef int c_data_type_int = tma_dt
        cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int

        cdef intptr_t global_address = view.ptr
        shape = view.shape

        cdef int rank = len(shape)
        if rank < 3 or rank > 5:
            raise ValueError(
                f"Im2col tensor rank must be between 3 and 5, got {rank}")

        # The two outermost-and-innermost non-spatial dims are excluded;
        # the remaining rank - 2 dims carry the pixel box corners.
        cdef int n_spatial = rank - 2
        if len(pixel_box_lower_corner) != n_spatial:
            raise ValueError(
                f"pixel_box_lower_corner must have {n_spatial} elements "
                f"(rank - 2), got {len(pixel_box_lower_corner)}")
        if len(pixel_box_upper_corner) != n_spatial:
            raise ValueError(
                f"pixel_box_upper_corner must have {n_spatial} elements "
                f"(rank - 2), got {len(pixel_box_upper_corner)}")

        element_strides = _validate_element_strides(element_strides, rank)

        cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
        byte_strides = _compute_byte_strides(shape, view.strides, elem_size)

        # Reverse all dimension arrays for column-major convention
        cdef uint64_t[5] c_global_dim
        cdef uint64_t[4] c_global_strides
        cdef uint32_t[5] c_element_strides
        cdef int[3] c_pixel_box_lower  # max 3 spatial dims (rank 5 - 2)
        cdef int[3] c_pixel_box_upper
        cdef int i_c

        # Zero-fill so the corner entries beyond n_spatial are well-defined.
        for i_c in range(3):
            c_pixel_box_lower[i_c] = 0
            c_pixel_box_upper[i_c] = 0

        for i_c in range(rank):
            c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
            c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]

        # rank - 1 byte strides; the innermost stride is implicit.
        for i_c in range(rank - 1):
            c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]

        # Reverse spatial dimensions for lower/upper corners
        for i_c in range(n_spatial):
            c_pixel_box_lower[i_c] = <int>pixel_box_lower_corner[n_spatial - 1 - i_c]
            c_pixel_box_upper[i_c] = <int>pixel_box_upper_corner[n_spatial - 1 - i_c]

        # NOTE(review): unlike _from_tiled, the enum arguments are not
        # type-checked here; anything int() accepts is forwarded to the driver.
        cdef uint32_t c_rank = <uint32_t>rank
        cdef uint32_t c_channels = <uint32_t>channels_per_pixel
        cdef uint32_t c_pixels = <uint32_t>pixels_per_column
        cdef int c_interleave_int = int(interleave)
        cdef int c_swizzle_int = int(swizzle)
        cdef int c_l2_promotion_int = int(l2_promotion)
        cdef int c_oob_fill_int = int(oob_fill)
        cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
        cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
        cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
        cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int

        with nogil:
            HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2col(
                &desc._tensor_map,
                c_data_type,
                c_rank,
                <void*>global_address,
                c_global_dim,
                c_global_strides,
                c_pixel_box_lower,
                c_pixel_box_upper,
                c_channels,
                c_pixels,
                c_element_strides,
                c_interleave,
                c_swizzle,
                c_l2_promotion,
                c_oob_fill,
            ))

        desc._repr_info = {
            "method": "im2col",
            "rank": rank,
            "data_type": TensorMapDataType(tma_dt),
            "swizzle": swizzle,
        }

        return desc

916  

@classmethod
def _from_im2col_wide(cls, view, pixel_box_lower_corner_width, pixel_box_upper_corner_width,
                      channels_per_pixel, pixels_per_column, *,
                      element_strides=None,
                      data_type=None,
                      interleave=TensorMapInterleave.NONE,
                      mode=TensorMapIm2ColWideMode.W,
                      swizzle=TensorMapSwizzle.SWIZZLE_128B,
                      l2_promotion=TensorMapL2Promotion.NONE,
                      oob_fill=TensorMapOOBFill.NONE):
    """Encode an im2col-wide TMA descriptor for an already-validated view.

    In im2col-wide layout, elements are loaded only along the W (width)
    dimension. This variant is supported on compute capability 10.0+
    (Blackwell and later), and the encoder is only compiled in when
    cuda.core is built against CUDA 13 or newer.

    Parameters
    ----------
    view : StridedMemoryView
        Device-accessible view whose pointer is 16-byte aligned.
    pixel_box_lower_corner_width : int
        W-dimension lower corner of the pixel bounding box.
    pixel_box_upper_corner_width : int
        W-dimension upper corner of the pixel bounding box.
    channels_per_pixel : int
        Number of channels per pixel.
    pixels_per_column : int
        Number of pixels per column.
    element_strides : tuple of int, optional
        Per-dimension element traversal strides; all 1s when omitted.
    data_type : dtype-like or TensorMapDataType, optional
        Element-type override. When ``None`` the type is inferred from the
        tensor's dtype. NumPy or ``ml_dtypes`` dtype objects are preferred;
        the enum is still accepted for compatibility.
    interleave : TensorMapInterleave
        Interleave layout, ``NONE`` by default.
    mode : TensorMapIm2ColWideMode
        Im2col-wide mode, ``W`` by default.
    swizzle : TensorMapSwizzle
        Swizzle mode, ``SWIZZLE_128B`` by default.
    l2_promotion : TensorMapL2Promotion
        L2 promotion mode, ``NONE`` by default.
    oob_fill : TensorMapOOBFill
        Out-of-bounds fill mode, ``NONE`` by default.

    Returns
    -------
    TensorMapDescriptor

    Raises
    ------
    RuntimeError
        If cuda.core was built with ``CUDA_CORE_BUILD_MAJOR < 13``.
    ValueError
        If the tensor rank is outside [3, 5], the pointer is not
        16-byte aligned, or another constraint is violated.

    Errors reported by ``cuTensorMapEncodeIm2colWide`` are raised through
    ``HANDLE_RETURN``.
    """
    IF CUDA_CORE_BUILD_MAJOR < 13:
        raise RuntimeError(
            "TensorMapDescriptor._from_im2col_wide requires a CUDA 13+ build")
    ELSE:
        cdef TensorMapDescriptor result = cls.__new__(cls)

        # Validate the view, then record the source object, view, current
        # context and current device on the descriptor before encoding.
        _validate_tensor_map_view(view)
        result._source_ref = view.exporting_obj
        result._view_ref = view
        result._context = _get_current_context_ptr()
        result._device_id = _get_current_device_id()
        _require_view_device(view, result._device_id, "TensorMapDescriptor._from_im2col_wide")

        dtype_enum = _resolve_data_type(view, data_type)
        cdef int dtype_code = dtype_enum
        cdef intptr_t base_ptr = view.ptr
        dims = view.shape

        cdef int ndim = len(dims)
        if not (3 <= ndim <= 5):
            raise ValueError(
                f"Im2col-wide tensor rank must be between 3 and 5, got {ndim}")

        element_strides = _validate_element_strides(element_strides, ndim)

        cdef int itemsize = _TMA_DATA_TYPE_SIZE[dtype_enum]
        byte_strides = _compute_byte_strides(dims, view.strides, itemsize)

        # The driver takes every per-dimension array in the reverse
        # (column-major) order of the row-major Python shape.
        cdef uint64_t[5] dims_rev
        cdef uint64_t[4] strides_rev
        cdef uint32_t[5] elem_strides_rev
        cdef int axis, src

        for axis in range(ndim):
            src = ndim - 1 - axis
            dims_rev[axis] = <uint64_t>dims[src]
            elem_strides_rev[axis] = <uint32_t>element_strides[src]

        # Only ndim - 1 byte strides are passed: the last (innermost)
        # Python dimension's stride is never copied.
        for axis in range(ndim - 1):
            strides_rev[axis] = <uint64_t>byte_strides[ndim - 2 - axis]

        cdef uint32_t c_rank = <uint32_t>ndim
        cdef int lower_w = <int>pixel_box_lower_corner_width
        cdef int upper_w = <int>pixel_box_upper_corner_width
        cdef uint32_t channels = <uint32_t>channels_per_pixel
        cdef uint32_t pixels = <uint32_t>pixels_per_column

        # Coerce the enum-like arguments to plain ints (in argument order),
        # then reinterpret them as the matching driver enum types.
        cdef int interleave_code = int(interleave)
        cdef int mode_code = int(mode)
        cdef int swizzle_code = int(swizzle)
        cdef int l2_code = int(l2_promotion)
        cdef int oob_code = int(oob_fill)
        cdef cydriver.CUtensorMapDataType c_dtype = <cydriver.CUtensorMapDataType>dtype_code
        cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>interleave_code
        cdef cydriver.CUtensorMapIm2ColWideMode c_mode = <cydriver.CUtensorMapIm2ColWideMode>mode_code
        cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>swizzle_code
        cdef cydriver.CUtensorMapL2promotion c_l2 = <cydriver.CUtensorMapL2promotion>l2_code
        cdef cydriver.CUtensorMapFloatOOBfill c_oob = <cydriver.CUtensorMapFloatOOBfill>oob_code

        with nogil:
            HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2colWide(
                &result._tensor_map,
                c_dtype,
                c_rank,
                <void*>base_ptr,
                dims_rev,
                strides_rev,
                lower_w,
                upper_w,
                channels,
                pixels,
                elem_strides_rev,
                c_interleave,
                c_mode,
                c_swizzle,
                c_l2,
                c_oob,
            ))

        result._repr_info = {
            "method": "im2col_wide",
            "rank": ndim,
            "data_type": TensorMapDataType(dtype_enum),
            "swizzle": swizzle,
        }

        return result

1059  

def replace_address(self, tensor: object) -> None:
    """Re-point this descriptor at a different global memory buffer.

    Use this when the tensor data has been reallocated but its shape,
    strides and other encoding parameters are unchanged. Only the base
    address inside the descriptor is rewritten. The new tensor's layout is
    not compared against the one used at encode time, so the caller is
    responsible for keeping it consistent.

    Parameters
    ----------
    tensor : object
        Any object supporting DLPack or ``__cuda_array_interface__``,
        or a :obj:`~cuda.core.StridedMemoryView`. It must refer to
        device-accessible memory with a 16-byte-aligned pointer.
    """
    self._check_context_compat()
    new_view = _get_validated_view(tensor)
    _require_view_device(new_view, self._device_id, "replace_address")

    cdef intptr_t new_address = new_view.ptr

    with nogil:
        HANDLE_RETURN(cydriver.cuTensorMapReplaceAddress(&self._tensor_map, <void*>new_address))

    # Swap the keep-alive references only once the driver has accepted the
    # new address. If the call raises, the descriptor still holds the old
    # tensor it points at, so no dangling pointer is left behind.
    self._source_ref = new_view.exporting_obj
    self._view_ref = new_view

1090  

def __repr__(self) -> str:
    """Return a short summary of how this descriptor was encoded.

    A descriptor with no recorded encoding info renders as
    ``TensorMapDescriptor()``. Enum-valued fields (``data_type``,
    ``swizzle``) are shown by member name. Values without a ``.name``,
    such as a plain ``int`` passed as ``swizzle`` (the encoders only
    require ``int(swizzle)`` to succeed), are shown as-is instead of
    making ``repr()`` raise.
    """
    info = self._repr_info
    if info is None:
        return "TensorMapDescriptor()"
    parts = []
    if "method" in info:
        parts.append(info["method"])
    if "rank" in info:
        parts.append(f"rank={info['rank']}")
    if "data_type" in info:
        dtype = info["data_type"]
        parts.append(f"dtype={getattr(dtype, 'name', dtype)}")
    if "swizzle" in info:
        # Factories store the caller's raw swizzle argument, which may be a
        # plain int rather than a TensorMapSwizzle member.
        sw = info["swizzle"]
        parts.append(f"swizzle={getattr(sw, 'name', sw)}")
    return f"TensorMapDescriptor({', '.join(parts)})"