Coverage for cuda / core / _kernel_arg_handler.pyx: 76.10%
205 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from cpython.mem cimport PyMem_Malloc, PyMem_Free
6from libc.stdint cimport (intptr_t,
7 int8_t, int16_t, int32_t, int64_t,
8 uint8_t, uint16_t, uint32_t, uint64_t,)
9from libcpp cimport bool as cpp_bool
10from libcpp.complex cimport complex as cpp_complex
11from libcpp cimport nullptr
12from libcpp cimport vector
14import ctypes
16import numpy
18from cuda.core._memory import Buffer
19from cuda.core._utils.cuda_utils import driver
20from cuda.bindings cimport cydriver
# Concrete instantiations of the C++ complex template for single and double
# precision; both are members of the `supported_type` fused type.
ctypedef cpp_complex.complex[float] cpp_single_complex
ctypedef cpp_complex.complex[double] cpp_double_complex
# Copying fp16 scalars on the host requires a type identifier for them. Only
# the raw 16-bit storage struct is needed, so this is a minimal definition
# borrowed from cuda_fp16.h and injected verbatim into the generated source.
cdef extern from *:
    """
    #if __cplusplus >= 201103L
    #define __CUDA_ALIGN__(n) alignas(n)    /* C++11 kindly gives us a keyword for this */
    #else
    #if defined(__GNUC__)
    #define __CUDA_ALIGN__(n) __attribute__ ((aligned(n)))
    #elif defined(_MSC_VER)
    #define __CUDA_ALIGN__(n) __declspec(align(n))
    #else
    #define __CUDA_ALIGN__(n)
    #endif /* defined(__GNUC__) */
    #endif /* __cplusplus >= 201103L */

    typedef struct __CUDA_ALIGN__(2) {
        /**
         * Storage field contains bits representation of the \p half floating-point number.
         */
        unsigned short x;
    } __half_raw;
    """
    # Cython-side view of the struct defined above: `x` holds the raw fp16 bits.
    ctypedef struct __half_raw:
        unsigned short x
# Every C/C++ scalar type that `prepare_arg` can be specialized for. A kernel
# argument is copied into host storage of exactly one of these types.
ctypedef fused supported_type:
    cpp_bool
    # signed integers
    int8_t
    int16_t
    int32_t
    int64_t
    # unsigned integers
    uint8_t
    uint16_t
    uint32_t
    uint64_t
    # floating point (fp16 is carried as its raw bit pattern)
    __half_raw
    float
    double
    # pointer-sized integer, used for addresses
    intptr_t
    # complex
    cpp_single_complex
    cpp_double_complex
# Module-level aliases for the ctypes/numpy scalar type objects. Binding them
# once here lets the dispatch functions compare against a cached object
# instead of performing an attribute lookup on every check.

# ctypes scalar types
cdef object ctypes_bool = ctypes.c_bool
cdef object ctypes_int8 = ctypes.c_int8
cdef object ctypes_int16 = ctypes.c_int16
cdef object ctypes_int32 = ctypes.c_int32
cdef object ctypes_int64 = ctypes.c_int64
cdef object ctypes_uint8 = ctypes.c_uint8
cdef object ctypes_uint16 = ctypes.c_uint16
cdef object ctypes_uint32 = ctypes.c_uint32
cdef object ctypes_uint64 = ctypes.c_uint64
cdef object ctypes_float = ctypes.c_float
cdef object ctypes_double = ctypes.c_double

# numpy scalar types
cdef object numpy_bool = numpy.bool_
cdef object numpy_int8 = numpy.int8
cdef object numpy_int16 = numpy.int16
cdef object numpy_int32 = numpy.int32
cdef object numpy_int64 = numpy.int64
cdef object numpy_uint8 = numpy.uint8
cdef object numpy_uint16 = numpy.uint16
cdef object numpy_uint32 = numpy.uint32
cdef object numpy_uint64 = numpy.uint64
cdef object numpy_float16 = numpy.float16
cdef object numpy_float32 = numpy.float32
cdef object numpy_float64 = numpy.float64
cdef object numpy_complex64 = numpy.complex64
cdef object numpy_complex128 = numpy.complex128
# Alias for `void*` so that it can be used as a C++ template argument
# (e.g. vector.vector[voidptr]); a limitation due to cython/cython#534.
ctypedef void* voidptr
# Allocate host storage for one scalar kernel argument, convert the Python
# object `arg` into it, and record the storage at position `idx`.
#
# Parameters:
#   data            - vector that owns the allocations; data[idx] receives the
#                     new buffer so that it can be released later with
#                     PyMem_Free.
#   data_addresses  - vector of argument addresses; data_addresses[idx]
#                     receives the address of the converted scalar.
#   arg             - the Python object to convert (kept untyped on purpose).
#   idx             - slot to fill in both vectors.
#   __unused        - never read; it exists only to select the specialization.
#
# Returns 0 on success. On failure a Python exception is set and -1 is
# returned (`except -1`): MemoryError when the allocation fails, or whatever
# the conversion of `arg` raises.
#
# Cython can't infer the overload without at least one input argument with fused type
cdef inline int prepare_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,  # important: keep it a Python object and don't cast
        const size_t idx,
        const supported_type* __unused=NULL) except -1:
    cdef void* ptr = PyMem_Malloc(sizeof(supported_type))
    if ptr == NULL:
        # PyMem_Malloc reports failure by returning NULL; writing through it
        # below would be undefined behavior.
        raise MemoryError("failed to allocate host storage for a kernel argument")
    # Register the buffer with its owner *before* converting: the conversions
    # below run Python code and may raise, and the buffer must remain reachable
    # by whoever frees the entries of `data`, otherwise it would leak.
    data[idx] = ptr  # for later dealloc
    # note: this should also work once ctypes has complex support:
    # python/cpython#121248
    if supported_type is cpp_single_complex:
        (<supported_type*>ptr)[0] = cpp_complex.complex[float](arg.real, arg.imag)
    elif supported_type is cpp_double_complex:
        (<supported_type*>ptr)[0] = cpp_complex.complex[double](arg.real, arg.imag)
    elif supported_type is __half_raw:
        # fp16 has no native C type here: reinterpret the value as a 16-bit
        # integer and store the raw bit pattern.
        (<supported_type*>ptr).x = <int16_t>(arg.view(numpy_int16))
    else:
        (<supported_type*>ptr)[0] = <supported_type>(arg)
    data_addresses[idx] = ptr  # take the address to the scalar
    return 0
# Try to interpret `arg` as a ctypes scalar and, if it is one, store its
# `.value` at slot `idx` through the matching `prepare_arg` specialization.
#
# Returns the result of `prepare_arg` (0) when the argument was handled, and 1
# when `arg` is not an instance of any of the ctypes types checked below, so
# that the caller can try something else. Errors propagate as Python
# exceptions (`except -1`).
cdef inline int prepare_ctypes_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,
        const size_t idx) except -1:
    cdef object exact_type = type(arg)

    # Fast path: identity comparison against the cached ctypes type objects.
    if exact_type is ctypes_bool:
        return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int8:
        return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int16:
        return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int32:
        return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int64:
        return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint8:
        return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint16:
        return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint32:
        return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint64:
        return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_float:
        return prepare_arg[float](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_double:
        return prepare_arg[double](data, data_addresses, arg.value, idx)

    # Slow path: no exact match, so fall back to `isinstance`, which also
    # accepts subclasses of the ctypes types.
    if isinstance(arg, ctypes_bool):
        return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int8):
        return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int16):
        return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int32):
        return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int64):
        return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint8):
        return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint16):
        return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint32):
        return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint64):
        return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_float):
        return prepare_arg[float](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_double):
        return prepare_arg[double](data, data_addresses, arg.value, idx)

    # Not a ctypes scalar handled here.
    return 1
# Try to interpret `arg` as a numpy scalar and, if it is one, store it at slot
# `idx` through the matching `prepare_arg` specialization.
#
# Returns the result of `prepare_arg` (0) when the argument was handled, and 1
# when `arg` is not an instance of any of the numpy scalar types checked
# below, so that the caller can try something else. Errors propagate as
# Python exceptions (`except -1`).
cdef inline int prepare_numpy_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,
        const size_t idx) except -1:
    cdef object exact_type = type(arg)

    # Fast path: identity comparison against the cached numpy type objects.
    if exact_type is numpy_bool:
        return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
    if exact_type is numpy_int8:
        return prepare_arg[int8_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int16:
        return prepare_arg[int16_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int32:
        return prepare_arg[int32_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int64:
        return prepare_arg[int64_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint8:
        return prepare_arg[uint8_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint16:
        return prepare_arg[uint16_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint32:
        return prepare_arg[uint32_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint64:
        return prepare_arg[uint64_t](data, data_addresses, arg, idx)
    if exact_type is numpy_float16:
        return prepare_arg[__half_raw](data, data_addresses, arg, idx)
    if exact_type is numpy_float32:
        return prepare_arg[float](data, data_addresses, arg, idx)
    if exact_type is numpy_float64:
        return prepare_arg[double](data, data_addresses, arg, idx)
    if exact_type is numpy_complex64:
        return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
    if exact_type is numpy_complex128:
        return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)

    # Slow path: no exact match, so fall back to `isinstance`, which also
    # accepts subclasses of the numpy scalar types.
    if isinstance(arg, numpy_bool):
        return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int8):
        return prepare_arg[int8_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int16):
        return prepare_arg[int16_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int32):
        return prepare_arg[int32_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int64):
        return prepare_arg[int64_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint8):
        return prepare_arg[uint8_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint16):
        return prepare_arg[uint16_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint32):
        return prepare_arg[uint32_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint64):
        return prepare_arg[uint64_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float16):
        return prepare_arg[__half_raw](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float32):
        return prepare_arg[float](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float64):
        return prepare_arg[double](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_complex64):
        return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_complex128):
        return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)

    # Not a numpy scalar handled here.
    return 1
cdef class ParamHolder:
    """Pack a sequence of Python kernel arguments into an array of argument addresses.

    Each supported argument is copied into host storage of a fixed C type, and
    the address of that storage is recorded. After construction, ``ptr`` holds
    the address of the first element of that address array as an integer, or
    0 when there are no arguments.

    The attributes ``data``, ``data_addresses``, ``ptr`` and ``kernel_args``
    are declared outside this file section (presumably in the accompanying
    ``.pxd`` -- confirm there).
    """

    def __init__(self, kernel_args):
        """Convert every entry of ``kernel_args``.

        ``kernel_args`` must support ``len()`` and iteration. Supported
        entries are ``Buffer``, ``bool``, ``int`` (treated as a pointer
        address), ``float``, ``complex``, numpy scalars, ctypes scalars and
        ``driver.CUgraphConditionalHandle``, plus subclasses of these.

        Raises:
            TypeError: if an argument is of an unsupported type.
        """
        cdef size_t n_args = len(kernel_args)
        cdef size_t i
        cdef int not_prepared
        cdef object arg_type

        if n_args == 0:
            self.ptr = 0
            return

        # `data` owns the host copies (null where nothing was allocated);
        # `data_addresses` is the array whose address ends up in `self.ptr`.
        self.data = vector.vector[voidptr](n_args, nullptr)
        self.data_addresses = vector.vector[voidptr](n_args)
        for i, arg in enumerate(kernel_args):
            arg_type = type(arg)
            if arg_type is Buffer:
                # we need the address of where the actual buffer address is stored
                if type(arg.handle) is int:
                    # see note below on handling int arguments
                    prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
                    continue
                else:
                    # it's a CUdeviceptr: point at the storage inside the
                    # handle object itself, so nothing is allocated here
                    self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                    continue
            elif arg_type is bool:
                prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is int:
                # Here's the dilemma: We want to have a fast path to pass in Python
                # integers as pointer addresses, but one could also (mistakenly) pass
                # it with the intention of passing a scalar integer. It's a mistake
                # because a Python int is ambiguous (arbitrary width). Our judgement
                # call here is to treat it as a pointer address, without any warning!
                prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is float:
                prepare_arg[double](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is complex:
                prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                continue

            # Not an exact builtin/Buffer type: try numpy scalars, then ctypes
            # scalars. Both helpers return non-zero when they did not handle
            # the argument.
            not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
            if not_prepared:
                not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
            if not_prepared:
                # TODO: revisit this treatment if we decide to cythonize cuda.core
                if arg_type is driver.CUgraphConditionalHandle:
                    # Pass a plain Python int and let prepare_arg convert it to
                    # the handle's own C type. Casting through the signed
                    # intptr_t first would raise OverflowError for values above
                    # INTPTR_MAX.
                    prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, int(arg), i)
                    continue
                # If no exact types are found, fallback to slower `isinstance` check
                elif isinstance(arg, Buffer):
                    if isinstance(arg.handle, int):
                        prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
                        continue
                    else:
                        self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                        continue
                elif isinstance(arg, bool):
                    prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, int):
                    prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, float):
                    prepare_arg[double](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, complex):
                    prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, driver.CUgraphConditionalHandle):
                    # Same explicit conversion as the exact-type branch above.
                    prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, int(arg), i)
                    continue
                # TODO: support ctypes/numpy struct
                raise TypeError("the argument is of unsupported type: " + str(type(arg)))

        # Hold a reference to the arguments: some recorded addresses point
        # into the argument objects themselves (the CUdeviceptr case above).
        self.kernel_args = kernel_args
        self.ptr = <intptr_t>self.data_addresses.data()

    def __dealloc__(self):
        # Release every host copy made by prepare_arg; null entries are slots
        # for which nothing was allocated.
        for data in self.data:
            if data:
                PyMem_Free(data)