Coverage for cuda / core / experimental / _kernel_arg_handler.pyx: 76%
205 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from cpython.mem cimport PyMem_Malloc, PyMem_Free
6from libc.stdint cimport (intptr_t,
7 int8_t, int16_t, int32_t, int64_t,
8 uint8_t, uint16_t, uint32_t, uint64_t,)
9from libcpp cimport bool as cpp_bool
10from libcpp.complex cimport complex as cpp_complex
11from libcpp cimport nullptr
12from libcpp cimport vector
14import ctypes
16import numpy
18from cuda.core.experimental._memory import Buffer
19from cuda.core.experimental._utils.cuda_utils import driver
20from cuda.bindings cimport cydriver
# Short aliases for the single- and double-precision specializations of the
# C++ complex template; both are members of the `supported_type` fused type.
ctypedef cpp_complex.complex[float] cpp_single_complex
ctypedef cpp_complex.complex[double] cpp_double_complex
# We need an identifier for fp16 for copying scalars on the host. This is a minimal
# implementation borrowed from cuda_fp16.h.
cdef extern from *:
    """
    #if __cplusplus >= 201103L
    #define __CUDA_ALIGN__(n) alignas(n) /* C++11 kindly gives us a keyword for this */
    #else
    #if defined(__GNUC__)
    #define __CUDA_ALIGN__(n) __attribute__ ((aligned(n)))
    #elif defined(_MSC_VER)
    #define __CUDA_ALIGN__(n) __declspec(align(n))
    #else
    #define __CUDA_ALIGN__(n)
    #endif /* defined(__GNUC__) */
    #endif /* __cplusplus >= 201103L */

    typedef struct __CUDA_ALIGN__(2) {
        /**
         * Storage field contains bits representation of the \p half floating-point number.
         */
        unsigned short x;
    } __half_raw;
    """
    # Cython-side declaration of the struct defined in the verbatim C block
    # above; `x` is the 16-bit storage field holding the raw half bits.
    ctypedef struct __half_raw:
        unsigned short x
# Every C scalar type that `prepare_arg` can be specialized for. Callers pick
# the specialization explicitly, e.g. `prepare_arg[int32_t](...)`.
ctypedef fused supported_type:
    cpp_bool
    int8_t
    int16_t
    int32_t
    int64_t
    uint8_t
    uint16_t
    uint32_t
    uint64_t
    __half_raw
    float
    double
    intptr_t
    cpp_single_complex
    cpp_double_complex
# Cache the ctypes/numpy scalar type objects once at import time so the
# dispatch functions compare against module-level names instead of performing
# an attribute lookup on `ctypes` / `numpy` for every kernel argument.
cdef:
    # ctypes scalar classes
    object ctypes_bool = ctypes.c_bool
    object ctypes_int8 = ctypes.c_int8
    object ctypes_int16 = ctypes.c_int16
    object ctypes_int32 = ctypes.c_int32
    object ctypes_int64 = ctypes.c_int64
    object ctypes_uint8 = ctypes.c_uint8
    object ctypes_uint16 = ctypes.c_uint16
    object ctypes_uint32 = ctypes.c_uint32
    object ctypes_uint64 = ctypes.c_uint64
    object ctypes_float = ctypes.c_float
    object ctypes_double = ctypes.c_double

    # numpy scalar classes
    object numpy_bool = numpy.bool_
    object numpy_int8 = numpy.int8
    object numpy_int16 = numpy.int16
    object numpy_int32 = numpy.int32
    object numpy_int64 = numpy.int64
    object numpy_uint8 = numpy.uint8
    object numpy_uint16 = numpy.uint16
    object numpy_uint32 = numpy.uint32
    object numpy_uint64 = numpy.uint64
    object numpy_float16 = numpy.float16
    object numpy_float32 = numpy.float32
    object numpy_float64 = numpy.float64
    object numpy_complex64 = numpy.complex64
    object numpy_complex128 = numpy.complex128
# limitation due to cython/cython#534
# `voidptr` is a plain alias for `void*`; it is the spelling used as the
# template argument when constructing `vector.vector[voidptr](...)`.
ctypedef void* voidptr
# Copy one scalar kernel argument into freshly allocated host memory.
#
# The Python object `arg` is converted to the C type selected by the fused
# specialization (e.g. `prepare_arg[int32_t](...)`), stored in a block obtained
# from PyMem_Malloc, and the block's address is written to both
# `data_addresses[idx]` (the address of the scalar) and `data[idx]` (kept so the
# block can be released later with PyMem_Free).
#
# Parameters:
#   data            -- per-argument allocation slots; slot `idx` is overwritten.
#   data_addresses  -- per-argument address slots; slot `idx` is overwritten.
#   arg             -- the Python object to convert.
#   idx             -- index of the slot to fill in both vectors.
#   __unused        -- never read; see the note below.
#
# Returns 0 on success. Returns -1 with a Python exception set if allocation
# fails (MemoryError) or if converting `arg` raises.
#
# Cython can't infer the overload without at least one input argument with fused type
cdef inline int prepare_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,  # important: keep it a Python object and don't cast
        const size_t idx,
        const supported_type* __unused=NULL) except -1:
    cdef void* ptr = PyMem_Malloc(sizeof(supported_type))
    if ptr == NULL:
        # PyMem_Malloc reports failure by returning NULL without setting an
        # exception; writing through it below would be undefined behaviour.
        raise MemoryError("failed to allocate host storage for a kernel argument")
    # Record the allocation *before* converting `arg`: the conversions below
    # can raise (overflow, wrong type, missing attribute), and once the block
    # is in `data` whoever owns that vector can still free it.
    data[idx] = ptr  # for later dealloc
    # note: this should also work once ctypes has complex support:
    # python/cpython#121248
    if supported_type is cpp_single_complex:
        (<supported_type*>ptr)[0] = cpp_complex.complex[float](arg.real, arg.imag)
    elif supported_type is cpp_double_complex:
        (<supported_type*>ptr)[0] = cpp_complex.complex[double](arg.real, arg.imag)
    elif supported_type is __half_raw:
        # Reinterpret the half-precision value as a 16-bit integer and store
        # those raw bits in the struct's storage field.
        (<supported_type*>ptr).x = <int16_t>(arg.view(numpy_int16))
    else:
        (<supported_type*>ptr)[0] = <supported_type>(arg)
    data_addresses[idx] = ptr  # take the address to the scalar
    return 0
# Try to treat `arg` as a ctypes scalar and stage its `.value` for the launch.
#
# The exact type of `arg` is first compared by identity against the cached
# ctypes classes; if none matches, the same list is tried again with
# `isinstance` so that subclasses are accepted too. `arg.value` is only read
# once a matching class has been found.
#
# Returns 0 when the argument was prepared, 1 when `arg` is not one of the
# handled ctypes scalars, and -1 (exception set) if preparation raised.
cdef inline int prepare_ctypes_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,
        const size_t idx) except -1:
    cdef object exact_type = type(arg)

    # Fast path: identity comparison against the cached classes.
    if exact_type is ctypes_bool:
        return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int8:
        return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int16:
        return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int32:
        return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int64:
        return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint8:
        return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint16:
        return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint32:
        return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint64:
        return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_float:
        return prepare_arg[float](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_double:
        return prepare_arg[double](data, data_addresses, arg.value, idx)

    # If no exact types are found, fallback to slower `isinstance` check
    if isinstance(arg, ctypes_bool):
        return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int8):
        return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int16):
        return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int32):
        return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int64):
        return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint8):
        return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint16):
        return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint32):
        return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint64):
        return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_float):
        return prepare_arg[float](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_double):
        return prepare_arg[double](data, data_addresses, arg.value, idx)

    # Not a ctypes scalar this function knows about.
    return 1
# Try to treat `arg` as a numpy scalar and stage it for the launch.
#
# The exact type of `arg` is first compared by identity against the cached
# numpy scalar classes; if none matches, the same list is tried again with
# `isinstance` so that subclasses are accepted too. The numpy object itself
# (not an unwrapped value) is handed to `prepare_arg`.
#
# Returns 0 when the argument was prepared, 1 when `arg` is not one of the
# handled numpy scalars, and -1 (exception set) if preparation raised.
cdef inline int prepare_numpy_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,
        const size_t idx) except -1:
    cdef object exact_type = type(arg)

    # Fast path: identity comparison against the cached classes.
    if exact_type is numpy_bool:
        return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
    if exact_type is numpy_int8:
        return prepare_arg[int8_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int16:
        return prepare_arg[int16_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int32:
        return prepare_arg[int32_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int64:
        return prepare_arg[int64_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint8:
        return prepare_arg[uint8_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint16:
        return prepare_arg[uint16_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint32:
        return prepare_arg[uint32_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint64:
        return prepare_arg[uint64_t](data, data_addresses, arg, idx)
    if exact_type is numpy_float16:
        return prepare_arg[__half_raw](data, data_addresses, arg, idx)
    if exact_type is numpy_float32:
        return prepare_arg[float](data, data_addresses, arg, idx)
    if exact_type is numpy_float64:
        return prepare_arg[double](data, data_addresses, arg, idx)
    if exact_type is numpy_complex64:
        return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
    if exact_type is numpy_complex128:
        return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)

    # If no exact types are found, fallback to slower `isinstance` check
    if isinstance(arg, numpy_bool):
        return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int8):
        return prepare_arg[int8_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int16):
        return prepare_arg[int16_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int32):
        return prepare_arg[int32_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int64):
        return prepare_arg[int64_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint8):
        return prepare_arg[uint8_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint16):
        return prepare_arg[uint16_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint32):
        return prepare_arg[uint32_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint64):
        return prepare_arg[uint64_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float16):
        return prepare_arg[__half_raw](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float32):
        return prepare_arg[float](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float64):
        return prepare_arg[double](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_complex64):
        return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_complex128):
        return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)

    # Not a numpy scalar this function knows about.
    return 1
cdef class ParamHolder:
    """Host-side staging area for the arguments of a kernel launch.

    For every entry of ``kernel_args`` one ``void*`` is stored in
    ``data_addresses``. Scalars are copied into memory allocated by
    ``prepare_arg`` (tracked in ``data`` and released in ``__dealloc__``);
    a ``Buffer`` whose handle is not a Python ``int`` contributes the pointer
    returned by ``handle.getPtr()`` directly, with nothing allocated.

    ``ptr`` is the address of the ``data_addresses`` array, or 0 when
    ``kernel_args`` is empty.

    Raises ``TypeError`` from ``__init__`` if an argument has an unsupported
    type; errors raised while converting an argument propagate unchanged.
    """

    cdef:
        # One slot per argument: the PyMem_Malloc'ed block holding a copied
        # scalar, or NULL when nothing was allocated for that argument.
        vector.vector[void*] data
        # One slot per argument: the pointer recorded for that argument.
        vector.vector[void*] data_addresses
        # The argument sequence as passed in, stored once every argument has
        # been prepared (presumably to keep the objects alive while raw
        # pointers into them are held -- confirm with callers).
        object kernel_args
        # Address of data_addresses' storage; 0 for an empty argument list.
        readonly intptr_t ptr

    def __init__(self, kernel_args):
        if len(kernel_args) == 0:
            self.ptr = 0
            return

        cdef size_t n_args = len(kernel_args)
        cdef size_t i
        cdef int not_prepared
        cdef object arg_type
        self.data = vector.vector[voidptr](n_args, nullptr)
        self.data_addresses = vector.vector[voidptr](n_args)
        for i, arg in enumerate(kernel_args):
            arg_type = type(arg)
            if arg_type is Buffer:
                # we need the address of where the actual buffer address is stored
                if type(arg.handle) is int:
                    # see note below on handling int arguments
                    prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
                    continue
                else:
                    # it's a CUdeviceptr:
                    self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                    continue
            elif arg_type is bool:
                prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is int:
                # Here's the dilemma: We want to have a fast path to pass in Python
                # integers as pointer addresses, but one could also (mistakenly) pass
                # it with the intention of passing a scalar integer. It's a mistake
                # because a Python int is ambiguous (arbitrary width). Our judgement
                # call here is to treat it as a pointer address, without any warning!
                prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is float:
                prepare_arg[double](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is complex:
                prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                continue

            # Not an exact builtin/Buffer match: try numpy scalars, then ctypes
            # scalars. Both helpers return non-zero when they did not handle it.
            not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
            if not_prepared:
                not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
            if not_prepared:
                # TODO: revisit this treatment if we decide to cythonize cuda.core
                if arg_type is driver.CUgraphConditionalHandle:
                    # Hand over the plain Python int. Do not round-trip it
                    # through a signed intptr_t first: that would overflow for
                    # values >= 2**63 before the unsigned conversion happens.
                    prepare_arg[cydriver.CUgraphConditionalHandle](
                        self.data, self.data_addresses, int(arg), i)
                    continue
                # If no exact types are found, fallback to slower `isinstance` check
                elif isinstance(arg, Buffer):
                    if isinstance(arg.handle, int):
                        prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
                        continue
                    else:
                        self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                        continue
                elif isinstance(arg, bool):
                    prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, int):
                    prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, float):
                    prepare_arg[double](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, complex):
                    prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, driver.CUgraphConditionalHandle):
                    # Same conversion as the exact-type branch above, so that
                    # subclasses are handled identically.
                    prepare_arg[cydriver.CUgraphConditionalHandle](
                        self.data, self.data_addresses, int(arg), i)
                    continue
                # TODO: support ctypes/numpy struct
                raise TypeError("the argument is of unsupported type: " + str(type(arg)))

        self.kernel_args = kernel_args
        self.ptr = <intptr_t>self.data_addresses.data()

    def __dealloc__(self):
        # Release every scalar block allocated by prepare_arg; slots left at
        # NULL (nothing allocated for that argument) are skipped.
        for data in self.data:
            if data:
                PyMem_Free(data)