Coverage for cuda / core / _linker.pyx: 63.54%

362 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-08 01:07 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4"""Linking machinery for combining object codes. 

5  

6This module provides :class:`Linker` for linking one or more 

7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for 

8configuration. 

9""" 

10  

11from __future__ import annotations 

12  

13from cpython.bytearray cimport PyByteArray_AS_STRING 

14from libc.stdint cimport intptr_t, uint32_t 

15from libcpp.vector cimport vector 

16from cuda.bindings cimport cydriver 

17from cuda.bindings cimport cynvjitlink 

18  

19from ._resource_handles cimport ( 

20 as_cu, 

21 as_py, 

22 create_culink_handle, 

23 create_nvjitlink_handle, 

24) 

25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK 

26  

27import sys 

28from dataclasses import dataclass 

29from typing import Union 

30from warnings import warn 

31  

32from cuda.pathfinder import optional_cuda_import 

33from cuda.core._device import Device 

34from cuda.core._module import ObjectCode 

35from cuda.core._utils.clear_error_support import assert_type 

36from cuda.core._utils.cuda_utils import ( 

37 CUDAError, 

38 check_or_create_options, 

39 driver, 

40 handle_return, 

41 is_sequence, 

42) 

43  

# C-level aliases so these pointer types can be used as the element type of a
# libcpp ``vector[...]`` (see the option vectors built in Linker_init).
ctypedef const char* const_char_ptr
ctypedef void* void_ptr

# Public names exported by this module.
__all__ = ["Linker", "LinkerOptions"]

# Type of the object returned by ``Linker.handle``; which member of the union
# applies depends on the backend in use.
LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"]

50  

51  

52# ============================================================================= 

53# Principal class 

54# ============================================================================= 

55  

cdef class Linker:
    """Link one or more object codes into a single :class:`~cuda.core.ObjectCode`.

    The same front end is exposed regardless of which linker library does
    the work underneath (nvJitLink, or the cuLink* entry points of the CUDA
    driver).

    Parameters
    ----------
    object_codes : :class:`~cuda.core.ObjectCode`
        One or more ObjectCode objects to be linked.
    options : :class:`LinkerOptions`, optional
        Options for the linker. If not provided, default options will be used.
    """

    def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
        # All construction work lives in the module-level cdef helper.
        Linker_init(self, object_codes, options)

    def link(self, target_type) -> ObjectCode:
        """Complete the link and return the result in the requested format.

        Parameters
        ----------
        target_type : str
            The type of the target output. Must be either "cubin" or "ptx".

        Returns
        -------
        :class:`~cuda.core.ObjectCode`
            The linked object code of the specified target type.

        .. note::

            Ensure that input object codes were compiled with appropriate
            flags for linking (e.g., relocatable device code enabled).
        """
        return Linker_link(self, target_type)

    def get_error_log(self) -> str:
        """Return the error log produced by the linker.

        Returns
        -------
        str
            The error log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_dst

        # link() stores the decoded text; prefer it when present.
        if self._error_log is not None:
            return self._error_log

        if not self._use_nvjitlink:
            # Driver backend: index 2 of the stored option values is the
            # error-log buffer; drop the unused NUL padding at its end.
            return (<bytearray>self._drv_log_bufs[2]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')

        # nvJitLink backend: query the size, then copy the log out.
        c_handle = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetErrorLogSize(c_handle, &c_nbytes)
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_dst = <char*>(<bytearray>buf)
            cynvjitlink.nvJitLinkGetErrorLog(c_handle, c_dst)
        return buf.decode("utf-8", errors="backslashreplace")

    def get_info_log(self) -> str:
        """Return the info log produced by the linker.

        Returns
        -------
        str
            The info log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_dst

        # link() stores the decoded text; prefer it when present.
        if self._info_log is not None:
            return self._info_log

        if not self._use_nvjitlink:
            # Driver backend: index 0 of the stored option values is the
            # info-log buffer; drop the unused NUL padding at its end.
            return (<bytearray>self._drv_log_bufs[0]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')

        # nvJitLink backend: query the size, then copy the log out.
        c_handle = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetInfoLogSize(c_handle, &c_nbytes)
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_dst = <char*>(<bytearray>buf)
            cynvjitlink.nvJitLinkGetInfoLog(c_handle, c_dst)
        return buf.decode("utf-8", errors="backslashreplace")

    def close(self):
        """Destroy this linker."""
        # Only the handle belonging to the active backend is reset.
        if not self._use_nvjitlink:
            self._culink_handle.reset()
        else:
            self._nvjitlink_handle.reset()

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

            The type of the returned object depends on the backend.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Linker.handle)``.
        """
        if not self._use_nvjitlink:
            return as_py(self._culink_handle)
        return as_py(self._nvjitlink_handle)

    @property
    def backend(self) -> str:
        """Return this Linker instance's underlying backend."""
        if self._use_nvjitlink:
            return "nvJitLink"
        return "driver"

175  

176  

177# ============================================================================= 

178# Supporting classes 

179# ============================================================================= 

180  

@dataclass
class LinkerOptions:
    """Customizable options for configuring :class:`Linker`.

    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated `ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [str | tuple[str] | list[str]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [str | tuple[str] | list[str]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [str | tuple[str] | list[str]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: str | tuple[str] | list[str] | None = None
    variables_used: str | tuple[str] | list[str] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: str | tuple[str] | list[str] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        # Creating options triggers the one-time backend selection.
        _lazy_init()
        self._name = self.name.encode()

    def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]:
        """Translate the option fields into nvJitLink command-line style flags.

        Parameters
        ----------
        as_bytes : bool, optional
            When True, every flag is returned encoded as ``bytes``; otherwise
            the flags are returned as ``str``.

        Returns
        -------
        list[bytes] | list[str]
            The flags, in a fixed order that follows the field declarations.
        """
        options = []

        if self.arch is not None:
            options.append(f"-arch={self.arch}")
        else:
            # No explicit arch: derive sm_<CC> from the current device.
            options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            options.append(f"-maxrregcount={self.max_register_count}")
        # Plain on/off switches are emitted only when truthy, so an explicit
        # ``time=False`` does not turn the switch on.
        if self.time:
            options.append("-time")
        if self.verbose:
            options.append("-verbose")
        if self.link_time_optimization:
            options.append("-lto")
        if self.ptx:
            options.append("-ptx")
        if self.optimization_level is not None:
            options.append(f"-O{self.optimization_level}")
        if self.debug:
            options.append("-g")
        if self.lineinfo:
            options.append("-lineinfo")
        # Valued booleans are forwarded whenever they are set, because both
        # True and False are meaningful to the linker.
        if self.ftz is not None:
            options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            options.append(f"-fma={'true' if self.fma else 'false'}")
        # A single name or a list/tuple of names is accepted; each name gets
        # its own flag. Tuples are part of the documented input types and
        # must not be silently ignored.
        for flag, names in (
            ("-kernels-used", self.kernels_used),
            ("-variables-used", self.variables_used),
        ):
            if names is None:
                continue
            if isinstance(names, str):
                options.append(f"{flag}={names}")
            elif isinstance(names, (list, tuple)):
                options.extend(f"{flag}={entry}" for entry in names)
        if self.optimize_unused_variables:
            options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            if isinstance(self.ptxas_options, str):
                options.append(f"-Xptxas={self.ptxas_options}")
            elif is_sequence(self.ptxas_options):
                for opt in self.ptxas_options:
                    options.append(f"-Xptxas={opt}")
        if self.split_compile is not None:
            options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            options.append("-no-cache")

        if as_bytes:
            return [o.encode() for o in options]
        else:
            return options

    def _prepare_driver_options(self) -> tuple[list, list]:
        """Translate the option fields into cuLink (driver API) JIT options.

        Returns
        -------
        tuple[list, list]
            ``(formatted_options, option_keys)``: two parallel lists holding
            the option values and the matching ``CUjit_option`` keys. The
            first four entries are always the info-log buffer, its size, the
            error-log buffer and its size.

        Raises
        ------
        ValueError
            If an option that the driver API cannot express is requested.
        """
        formatted_options = []
        option_keys = []

        # allocate a fixed-sized buffer for each info/error log
        size = 4194304
        formatted_options.extend((bytearray(size), size, bytearray(size), size))
        option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_80" / "compute_80" -> "80" -> CU_TARGET_COMPUTE_80
            arch = self.arch.split("_")[-1].upper()
            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            formatted_options.append(self.max_register_count)
            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        # Only an actual request for timing is an error; ``time=False`` asks
        # for nothing the driver cannot provide.
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            formatted_options.append(self.optimization_level)
            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        # The options below are not forwarded to the driver; setting them
        # only produces a DeprecationWarning.
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

        return formatted_options, option_keys

    def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
        """Convert linker options to bytes format for the nvjitlink backend.

        Parameters
        ----------
        backend : str, optional
            The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink".

        Returns
        -------
        list[bytes]
            List of option strings encoded as bytes.

        Raises
        ------
        ValueError
            If an unsupported backend is specified.
        RuntimeError
            If nvJitLink backend is not available.
        """
        backend = backend.lower()
        if backend != "nvjitlink":
            raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
        if not _use_nvjitlink_backend:
            raise RuntimeError("nvJitLink backend is not available")
        return self._prepare_nvjitlink_options(as_bytes=True)

437  

438  

439# ============================================================================= 

440# Private implementation: cdef inline helpers 

441# ============================================================================= 

442  

cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1:
    """Initialize a Linker instance.

    Creates the backend linker handle (nvJitLink or cuLink, depending on the
    module-level ``_use_nvjitlink_backend`` flag) from ``options`` and then
    adds every entry of ``object_codes`` to it.

    Returns 0 on success; a raised Python exception is reported to the
    caller through the ``except -1`` return convention.

    Raises ``ValueError`` when ``object_codes`` is empty.
    """
    if len(object_codes) == 0:
        raise ValueError("At least one ObjectCode object must be provided")

    cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink
    cdef cydriver.CUlinkState c_raw_culink
    cdef Py_ssize_t c_num_opts, i
    cdef vector[const_char_ptr] c_str_opts
    cdef vector[cydriver.CUjit_option] c_jit_keys
    cdef vector[void_ptr] c_jit_values

    # Normalizes ``options`` via the shared helper and stores the result on
    # the instance for later use.
    self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")

    if _use_nvjitlink_backend:
        self._use_nvjitlink = True
        options_bytes = options._prepare_nvjitlink_options(as_bytes=True)
        c_num_opts = len(options_bytes)
        c_str_opts.resize(c_num_opts)
        for i in range(c_num_opts):
            # Borrowed pointer into each bytes object; ``options_bytes`` stays
            # referenced by this frame for the duration of the create call.
            c_str_opts[i] = <const char*>(<bytes>options_bytes[i])
        with nogil:
            # NULL is passed as the handle argument of the error checker,
            # presumably because no linker handle exists yet.
            HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate(
                &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data()))
        self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink)
    else:
        self._use_nvjitlink = False
        formatted_options, option_keys = options._prepare_driver_options()
        # Keep the formatted_options list alive: it contains bytearrays that
        # the driver writes into via raw pointers during linking operations.
        self._drv_log_bufs = formatted_options
        c_num_opts = len(option_keys)
        c_jit_keys.resize(c_num_opts)
        c_jit_values.resize(c_num_opts)
        for i in range(c_num_opts):
            c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i]
            val = formatted_options[i]
            if isinstance(val, bytearray):
                # Log buffers are passed by address.
                c_jit_values[i] = <void*>PyByteArray_AS_STRING(val)
            else:
                # Every other option value is an integer carried in the
                # pointer-sized slot itself.
                c_jit_values[i] = <void*><intptr_t>int(val)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkCreate(
                    <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink))
        except CUDAError as e:
            # Attach the contents of the driver error log before re-raising.
            Linker_annotate_error_log(self, e)
            raise
        self._culink_handle = create_culink_handle(c_raw_culink)

    # Inputs are type-checked and added one by one, in the order given.
    for code in object_codes:
        assert_type(code, ObjectCode)
        Linker_add_code_object(self, code)
    return 0

497  

498  

cdef inline void Linker_add_code_object(Linker self, object object_code) except *:
    """Add a single ObjectCode to the linker.

    ``object_code.code`` is handed to the backend as an in-memory buffer when
    it is ``bytes`` and as a file path when it is ``str``.

    Raises ``ValueError`` when ``object_code.code_type`` has no entry in the
    active backend's input-type table, and ``TypeError`` when the code is
    neither ``bytes`` nor ``str``.
    """
    data = object_code.code
    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef cynvjitlink.nvJitLinkInputType c_nv_input_type
    cdef cydriver.CUjitInputType c_drv_input_type
    cdef const char* c_data_ptr
    cdef size_t c_data_size
    cdef const char* c_name_ptr
    cdef const char* c_file_ptr

    # ``name_bytes`` is kept in a local so that ``c_name_ptr`` (a pointer into
    # it) remains valid for the calls below.
    name_bytes = f"{object_code.name}".encode()
    c_name_ptr = <const char*>name_bytes

    # Map the textual code type to the backend-specific enum value.
    input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types
    py_input_type = input_types.get(object_code.code_type)
    if py_input_type is None:
        raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}")

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type
        if isinstance(data, bytes):
            # In-memory input.
            c_data_ptr = <const char*>(<bytes>data)
            c_data_size = len(data)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData(
                    c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr))
        elif isinstance(data, str):
            # The string is passed to the backend as a file name.
            file_bytes = data.encode()
            c_file_ptr = <const char*>file_bytes
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile(
                    c_nvjitlink_h, c_nv_input_type, c_file_ptr))
        else:
            raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
    else:
        c_culink_state = as_cu(self._culink_handle)
        c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type
        try:
            if isinstance(data, bytes):
                # In-memory input; no per-input JIT options are passed
                # (count 0, NULL keys/values).
                c_data_ptr = <const char*>(<bytes>data)
                c_data_size = len(data)
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddData(
                        c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr,
                        0, NULL, NULL))
            elif isinstance(data, str):
                # The string is passed to the driver as a file name.
                file_bytes = data.encode()
                c_file_ptr = <const char*>file_bytes
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddFile(
                        c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL))
            else:
                raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
        except CUDAError as e:
            # Attach the contents of the driver error log before re-raising.
            Linker_annotate_error_log(self, e)
            raise

558  

559  

cdef inline object Linker_link(Linker self, str target_type):
    """Complete linking and return the result as ObjectCode.

    ``target_type`` must be ``"cubin"`` or ``"ptx"``; anything else raises
    ``ValueError`` before the backend is touched. After a successful link
    the info and error logs are decoded and cached on the instance.
    """
    if target_type not in ("cubin", "ptx"):
        raise ValueError(f"Unsupported target type: {target_type}")

    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef size_t c_output_size = 0
    cdef char* c_code_ptr
    cdef void* c_cubin_out = NULL

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h))
        # Two-step retrieval: query the output size, then copy the image into
        # a Python-owned bytearray of exactly that size.
        if target_type == "cubin":
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr))
        else:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr))
    else:
        c_culink_state = as_cu(self._culink_handle)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size))
        except CUDAError as e:
            # Attach the contents of the driver error log before re-raising.
            Linker_annotate_error_log(self, e)
            raise
        # Slicing the C pointer copies ``c_output_size`` bytes into a new
        # Python bytes object.
        # NOTE(review): this branch does not look at ``target_type``; the
        # driver output is returned as-is for either value - confirm intent.
        code = (<char*>c_cubin_out)[:c_output_size]

    # Linking is complete; cache the decoded log strings and release
    # the driver's raw bytearray buffers (no longer written to).
    self._info_log = self.get_info_log()
    self._error_log = self.get_error_log()
    self._drv_log_bufs = None

    return ObjectCode._init(bytes(code), target_type, name=self._options.name)

608  

609  

cdef inline void Linker_annotate_error_log(Linker self, object e):
    """Append the linker error log, when it is non-empty, to the first argument of ``e``."""
    log_text = self.get_error_log()
    if not log_text:
        # Nothing to add; leave the exception untouched.
        return
    first_arg = e.args[0]
    remaining = e.args[1:]
    e.args = (first_arg + f"\nLinker error log: {log_text}", *remaining)

615  

616  

617# ============================================================================= 

618# Private implementation: module-level state and initialization 

619# ============================================================================= 

620  

621# TODO: revisit this treatment for py313t builds 

_driver = None  # populated if nvJitLink cannot be used
_driver_ver = None  # (major, minor) tuple once _decide_nvjitlink_or_driver() has run
_inited = False  # flipped to True at the end of the first _lazy_init() call
_use_nvjitlink_backend = False  # set by _decide_nvjitlink_or_driver()

# Input type mappings populated by _lazy_init() with C-level enum ints.
# Only the mapping for the selected backend is filled in; the other stays None.
_nvjitlink_input_types = None
_driver_input_types = None

630  

631  

def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
    """Report whether the loaded nvJitLink library exposes ``__nvJitLinkVersion``."""
    # The presence of this symbol is equivalent to the library being >= 12.3.
    symbol_address = nvjitlink._inspect_function_pointer("__nvJitLinkVersion")
    return True if symbol_address else False

635  

636  

637# Note: this function is reused in the tests 

def _decide_nvjitlink_or_driver() -> bool:
    """Pick the linking backend once and report whether the driver fallback is used.

    Returns True when the cuLink* driver APIs will be used and False when
    nvJitLink will be used. The decision is recorded in module-level state,
    so later calls return the stored answer without probing again.
    """
    global _driver_ver, _driver, _use_nvjitlink_backend

    # A recorded driver version means the decision has already been made.
    if _driver_ver is not None:
        return not _use_nvjitlink_backend

    # The driver reports e.g. 12040 for 12.4; split it into (major, minor).
    raw_version = handle_return(driver.cuDriverGetVersion())
    major, remainder = divmod(raw_version, 1000)
    _driver_ver = (major, remainder // 10)

    fallback_note = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    nvjitlink_module = optional_cuda_import(
        "cuda.bindings.nvjitlink",
        probe_function=lambda module: module.version(),  # probe triggers nvJitLink runtime load
    )
    if nvjitlink_module is not None:
        from cuda.bindings._internal import nvjitlink

        if _nvjitlink_has_version_symbol(nvjitlink):
            _use_nvjitlink_backend = True
            return False  # Use nvjitlink
        library_name = "nvJitLink*.dll" if sys.platform == "win32" else "libnvJitLink.so*"
        message = (
            f"{library_name} is too old (<12.3)."
            f" Therefore cuda.bindings.nvjitlink is not usable and {fallback_note} nvJitLink."
        )
    else:
        message = f"cuda.bindings.nvjitlink is not available, therefore {fallback_note} cuda-bindings."

    # Fallback path: tell the user why, and remember the driver module.
    warn(message, stacklevel=2, category=RuntimeWarning)
    _driver = driver
    return True

673  

674  

def _lazy_init():
    """Select the backend and build its input-type table on first use.

    Later calls do nothing once the module-level ``_inited`` flag is set.
    """
    global _inited, _nvjitlink_input_types, _driver_input_types
    if not _inited:
        _decide_nvjitlink_or_driver()
        # Only the table for the selected backend is populated.
        if not _use_nvjitlink_backend:
            _driver_input_types = {
                "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
                "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
                "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
                "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
                "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
            }
        else:
            _nvjitlink_input_types = {
                "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
                "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
                "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
                "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
                "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
                "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
            }
        _inited = True