Coverage for cuda / core / experimental / _kernel_arg_handler.pyx: 76%

205 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-10 01:19 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5from cpython.mem cimport PyMem_Malloc, PyMem_Free 

6from libc.stdint cimport (intptr_t, 

7 int8_t, int16_t, int32_t, int64_t, 

8 uint8_t, uint16_t, uint32_t, uint64_t,) 

9from libcpp cimport bool as cpp_bool 

10from libcpp.complex cimport complex as cpp_complex 

11from libcpp cimport nullptr 

12from libcpp cimport vector 

13  

14import ctypes 

15  

16import numpy 

17  

18from cuda.core.experimental._memory import Buffer 

19from cuda.core.experimental._utils.cuda_utils import driver 

20from cuda.bindings cimport cydriver 

21  

22  

# Concrete float and double instantiations of the C++ complex template; both
# names are members of the `supported_type` fused type declared in this file.
ctypedef cpp_complex.complex[float] cpp_single_complex
ctypedef cpp_complex.complex[double] cpp_double_complex

25  

26  

# We need an identifier for fp16 for copying scalars on the host. This is a minimal
# implementation borrowed from cuda_fp16.h.
#
# The verbatim C snippet defines `__half_raw` as a 2-byte-aligned struct with a
# single `unsigned short` member; the Cython-level declaration after the
# snippet mirrors that layout so the member `x` can be assigned from Cython.
cdef extern from *:
    """
    #if __cplusplus >= 201103L
    #define __CUDA_ALIGN__(n) alignas(n) /* C++11 kindly gives us a keyword for this */
    #else
    #if defined(__GNUC__)
    #define __CUDA_ALIGN__(n) __attribute__ ((aligned(n)))
    #elif defined(_MSC_VER)
    #define __CUDA_ALIGN__(n) __declspec(align(n))
    #else
    #define __CUDA_ALIGN__(n)
    #endif /* defined(__GNUC__) */
    #endif /* __cplusplus >= 201103L */

    typedef struct __CUDA_ALIGN__(2) {
        /**
         * Storage field contains bits representation of the \p half floating-point number.
         */
        unsigned short x;
    } __half_raw;
    """
    # Cython-side view of the struct defined in the snippet above.
    ctypedef struct __half_raw:
        unsigned short x

52  

53  

# Every scalar C/C++ type that `prepare_arg` can be specialized for. The
# specialization is always selected explicitly at the call site, e.g.
# `prepare_arg[int32_t](...)`.
ctypedef fused supported_type:
    # boolean
    cpp_bool
    # signed integers
    int8_t
    int16_t
    int32_t
    int64_t
    # unsigned integers
    uint8_t
    uint16_t
    uint32_t
    uint64_t
    # floating point (fp16 is carried as its raw 16-bit pattern)
    __half_raw
    float
    double
    # pointer-sized integer, used for addresses passed as Python ints
    intptr_t
    # complex
    cpp_single_complex
    cpp_double_complex

70  

71  

# Module-level aliases for the ctypes and numpy scalar classes. Binding them
# once at import time lets the dispatch helpers compare against a module
# variable instead of doing an attribute lookup for every kernel argument.
cdef:
    # ctypes simple types
    object ctypes_bool = ctypes.c_bool
    object ctypes_int8 = ctypes.c_int8
    object ctypes_int16 = ctypes.c_int16
    object ctypes_int32 = ctypes.c_int32
    object ctypes_int64 = ctypes.c_int64
    object ctypes_uint8 = ctypes.c_uint8
    object ctypes_uint16 = ctypes.c_uint16
    object ctypes_uint32 = ctypes.c_uint32
    object ctypes_uint64 = ctypes.c_uint64
    object ctypes_float = ctypes.c_float
    object ctypes_double = ctypes.c_double

    # numpy scalar types
    object numpy_bool = numpy.bool_
    object numpy_int8 = numpy.int8
    object numpy_int16 = numpy.int16
    object numpy_int32 = numpy.int32
    object numpy_int64 = numpy.int64
    object numpy_uint8 = numpy.uint8
    object numpy_uint16 = numpy.uint16
    object numpy_uint32 = numpy.uint32
    object numpy_uint64 = numpy.uint64
    object numpy_float16 = numpy.float16
    object numpy_float32 = numpy.float32
    object numpy_float64 = numpy.float64
    object numpy_complex64 = numpy.complex64
    object numpy_complex128 = numpy.complex128


# Single-identifier alias for `void*`, used as the template argument when
# constructing the vectors (limitation due to cython/cython#534).
ctypedef void* voidptr

102  

103  

# Cython can't infer the overload without at least one input argument with fused type
cdef inline int prepare_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,  # important: keep it a Python object and don't cast
        const size_t idx,
        const supported_type* __unused=NULL) except -1:
    """Copy the scalar ``arg`` into a newly allocated host slot of ``supported_type``.

    Parameters
    ----------
    data
        Vector that owns the allocations; ``data[idx]`` receives the new
        pointer so that it can be released with ``PyMem_Free`` later.
    data_addresses
        Vector of argument addresses; ``data_addresses[idx]`` is set to the
        address of the stored scalar.
    arg
        Python object holding the value. Complex specializations read
        ``arg.real``/``arg.imag``; the fp16 specialization reinterprets the
        value's bits through ``arg.view(numpy_int16)``; every other
        specialization converts ``arg`` with a plain C cast.
    idx
        Index of the slot to fill in both vectors.
    __unused
        Never read; present only so that the signature mentions the fused type.

    Returns
    -------
    0 on success. A Python exception is propagated as -1 (``MemoryError`` if
    the allocation fails, or whatever the conversion of ``arg`` raises).
    """
    cdef void* ptr = PyMem_Malloc(sizeof(supported_type))
    if ptr == NULL:
        # PyMem_Malloc reports failure by returning NULL without setting an
        # exception, so raise explicitly instead of writing through NULL.
        raise MemoryError()
    # Record the allocation before converting `arg`: if the conversion below
    # raises, the pointer is already tracked in `data` and is released together
    # with the other entries (ParamHolder.__dealloc__ frees every non-null
    # entry of its `data` vector) instead of being leaked.
    data[idx] = ptr  # for later dealloc
    # note: this should also work once ctypes has complex support:
    # python/cpython#121248
    if supported_type is cpp_single_complex:
        (<supported_type*>ptr)[0] = cpp_complex.complex[float](arg.real, arg.imag)
    elif supported_type is cpp_double_complex:
        (<supported_type*>ptr)[0] = cpp_complex.complex[double](arg.real, arg.imag)
    elif supported_type is __half_raw:
        # store the raw 16-bit pattern of the half-precision value
        (<supported_type*>ptr).x = <int16_t>(arg.view(numpy_int16))
    else:
        (<supported_type*>ptr)[0] = <supported_type>(arg)
    data_addresses[idx] = ptr  # take the address to the scalar
    return 0

125  

126  

cdef inline int prepare_ctypes_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,
        const size_t idx) except -1:
    """Store ``arg.value`` in slot ``idx`` when ``arg`` is a supported ctypes scalar.

    Returns the result of ``prepare_arg`` (0 on success) when ``arg`` matches
    one of the cached ctypes classes, and 1 when it matches none of them so
    that the caller can try another strategy. ``arg.value`` is only read once
    a class has matched.
    """
    cdef object exact_type = type(arg)

    # Fast path: identity comparison against the cached ctypes classes.
    if exact_type is ctypes_bool:
        return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int8:
        return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int16:
        return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int32:
        return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_int64:
        return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint8:
        return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint16:
        return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint32:
        return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_uint64:
        return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_float:
        return prepare_arg[float](data, data_addresses, arg.value, idx)
    if exact_type is ctypes_double:
        return prepare_arg[double](data, data_addresses, arg.value, idx)

    # Slow path: no exact match, so accept subclasses via `isinstance`.
    if isinstance(arg, ctypes_bool):
        return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int8):
        return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int16):
        return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int32):
        return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_int64):
        return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint8):
        return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint16):
        return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint32):
        return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_uint64):
        return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_float):
        return prepare_arg[float](data, data_addresses, arg.value, idx)
    if isinstance(arg, ctypes_double):
        return prepare_arg[double](data, data_addresses, arg.value, idx)

    # Not a ctypes scalar this module knows about.
    return 1

181  

182  

cdef inline int prepare_numpy_arg(
        vector.vector[void*]& data,
        vector.vector[void*]& data_addresses,
        arg,
        const size_t idx) except -1:
    """Store ``arg`` in slot ``idx`` when it is a supported numpy scalar.

    Returns the result of ``prepare_arg`` (0 on success) when ``arg`` matches
    one of the cached numpy scalar classes, and 1 when it matches none of them
    so that the caller can try another strategy.
    """
    cdef object exact_type = type(arg)

    # Fast path: identity comparison against the cached numpy classes.
    if exact_type is numpy_bool:
        return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
    if exact_type is numpy_int8:
        return prepare_arg[int8_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int16:
        return prepare_arg[int16_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int32:
        return prepare_arg[int32_t](data, data_addresses, arg, idx)
    if exact_type is numpy_int64:
        return prepare_arg[int64_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint8:
        return prepare_arg[uint8_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint16:
        return prepare_arg[uint16_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint32:
        return prepare_arg[uint32_t](data, data_addresses, arg, idx)
    if exact_type is numpy_uint64:
        return prepare_arg[uint64_t](data, data_addresses, arg, idx)
    if exact_type is numpy_float16:
        return prepare_arg[__half_raw](data, data_addresses, arg, idx)
    if exact_type is numpy_float32:
        return prepare_arg[float](data, data_addresses, arg, idx)
    if exact_type is numpy_float64:
        return prepare_arg[double](data, data_addresses, arg, idx)
    if exact_type is numpy_complex64:
        return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
    if exact_type is numpy_complex128:
        return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)

    # Slow path: no exact match, so accept subclasses via `isinstance`.
    if isinstance(arg, numpy_bool):
        return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int8):
        return prepare_arg[int8_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int16):
        return prepare_arg[int16_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int32):
        return prepare_arg[int32_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_int64):
        return prepare_arg[int64_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint8):
        return prepare_arg[uint8_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint16):
        return prepare_arg[uint16_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint32):
        return prepare_arg[uint32_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_uint64):
        return prepare_arg[uint64_t](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float16):
        return prepare_arg[__half_raw](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float32):
        return prepare_arg[float](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_float64):
        return prepare_arg[double](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_complex64):
        return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
    if isinstance(arg, numpy_complex128):
        return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)

    # Not a numpy scalar this module knows about.
    return 1

249  

250  

cdef class ParamHolder:
    """Convert a sequence of Python-level kernel arguments into an array of C pointers.

    For every argument one ``void*`` is stored in ``data_addresses``. Scalars
    are copied into individually allocated host slots (tracked in ``data`` and
    released in ``__dealloc__``); a ``Buffer`` whose handle is not a Python
    ``int`` contributes the address returned by ``handle.getPtr()`` directly.

    Attributes
    ----------
    ptr : intptr_t (read-only)
        Address of the ``data_addresses`` array as an integer, or 0 when no
        arguments were given.

    Raises
    ------
    TypeError
        If an argument is of a type that none of the dispatch paths handle.
    """

    cdef:
        # host allocations owned by this object (nullptr where nothing was allocated)
        vector.vector[void*] data
        # one address per kernel argument; `ptr` points at this array
        vector.vector[void*] data_addresses
        # reference to the original arguments, kept for the holder's lifetime
        object kernel_args
        readonly intptr_t ptr

    def __init__(self, kernel_args):
        cdef size_t n_args = len(kernel_args)
        cdef size_t i
        cdef int not_prepared
        cdef object arg_type

        if n_args == 0:
            self.ptr = 0
            return

        self.data = vector.vector[voidptr](n_args, nullptr)
        self.data_addresses = vector.vector[voidptr](n_args)
        for i, arg in enumerate(kernel_args):
            arg_type = type(arg)
            if arg_type is Buffer:
                # we need the address of where the actual buffer address is stored
                if type(arg.handle) is int:
                    # see note below on handling int arguments
                    prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
                    continue
                else:
                    # it's a CUdeviceptr:
                    self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                    continue
            elif arg_type is bool:
                prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is int:
                # Here's the dilemma: We want to have a fast path to pass in Python
                # integers as pointer addresses, but one could also (mistakenly) pass
                # it with the intention of passing a scalar integer. It's a mistake
                # because a Python int is ambiguous (arbitrary width). Our judgement
                # call here is to treat it as a pointer address, without any warning!
                prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is float:
                prepare_arg[double](self.data, self.data_addresses, arg, i)
                continue
            elif arg_type is complex:
                prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                continue

            # Not a builtin/Buffer exact match: try numpy scalars, then ctypes
            # scalars. Both helpers return non-zero when they did not handle `arg`.
            not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
            if not_prepared:
                not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
            if not_prepared:
                # TODO: revisit this treatment if we decide to cythonize cuda.core
                if arg_type is driver.CUgraphConditionalHandle:
                    # Pass the handle as a plain Python int. Routing it through a
                    # signed `<intptr_t>` cast first would raise OverflowError for
                    # handle values with the top bit set.
                    prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, int(arg), i)
                    continue
                # If no exact types are found, fallback to slower `isinstance` check
                elif isinstance(arg, Buffer):
                    if isinstance(arg.handle, int):
                        prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
                        continue
                    else:
                        self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                        continue
                elif isinstance(arg, bool):
                    prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, int):
                    prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, float):
                    prepare_arg[double](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, complex):
                    prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                    continue
                elif isinstance(arg, driver.CUgraphConditionalHandle):
                    # same conversion as the exact-type branch above
                    prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, int(arg), i)
                    continue
                # TODO: support ctypes/numpy struct
                raise TypeError("the argument is of unsupported type: " + str(type(arg)))

        # Hold a reference to the arguments: for the CUdeviceptr case the stored
        # address comes from `arg.handle.getPtr()`, i.e. from the argument object.
        self.kernel_args = kernel_args
        self.ptr = <intptr_t>self.data_addresses.data()

    def __dealloc__(self):
        # Release every scalar slot allocated by prepare_arg; entries that were
        # never allocated are still nullptr and are skipped.
        for data in self.data:
            if data:
                PyMem_Free(data)