Coverage for cuda / core / _linker.pyx: 64.64%

362 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 01:37 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4"""Linking machinery for combining object codes. 

5  

6This module provides :class:`Linker` for linking one or more 

7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for 

8configuration. 

9""" 

10  

11from __future__ import annotations 

12  

13from cpython.bytearray cimport PyByteArray_AS_STRING 

14from libc.stdint cimport intptr_t, uint32_t 

15from libcpp.vector cimport vector 

16from cuda.bindings cimport cydriver 

17from cuda.bindings cimport cynvjitlink 

18  

19from ._resource_handles cimport ( 

20 as_cu, 

21 as_py, 

22 create_culink_handle, 

23 create_nvjitlink_handle, 

24) 

25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK 

26  

27import sys 

28from dataclasses import dataclass 

29from typing import Union 

30from warnings import warn 

31  

32from cuda.pathfinder._optional_cuda_import import _optional_cuda_import 

33from cuda.core._device import Device 

34from cuda.core._module import ObjectCode 

35from cuda.core._utils.clear_error_support import assert_type 

36from cuda.core._utils.cuda_utils import ( 

37 CUDAError, 

38 check_or_create_options, 

39 driver, 

40 is_sequence, 

41) 

42from cuda.core.typing import CompilerBackendType 

43  

44ctypedef const char* const_char_ptr 

45ctypedef void* void_ptr 

46  

47__all__ = ["Linker", "LinkerOptions"] 

48  

49LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"] 

50  

51  

52# ============================================================================= 

53# Principal class 

54# ============================================================================= 

55  

56cdef class Linker: 

57 """Represent a linking machinery to link one or more object codes into 

58 :class:`~cuda.core.ObjectCode`. 

59  

60 This object provides a unified interface to multiple underlying 

61 linker libraries (such as nvJitLink or cuLink* from the CUDA driver). 

62  

63 Parameters 

64 ---------- 

65 object_codes : :class:`~cuda.core.ObjectCode` 

66 One or more ObjectCode objects to be linked. 

67 options : :class:`LinkerOptions`, optional 

68 Options for the linker. If not provided, default options will be used. 

69 """ 

70  

71 def __init__(self, *object_codes: ObjectCode, options: "LinkerOptions" = None): 

72 Linker_init(self, object_codes, options) 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

73  

74 def link(self, target_type: ObjectCodeFormatType | str) -> ObjectCode: 

75 """Link the provided object codes into a single output of the specified target type. 

76  

77 Parameters 

78 ---------- 

79 target_type : ObjectCodeFormatType | str 

80 The type of the target output. Must be either "cubin" or "ptx". 

81  

82 Returns 

83 ------- 

84 :class:`~cuda.core.ObjectCode` 

85 The linked object code of the specified target type. 

86  

87 .. note:: 

88  

89 Ensure that input object codes were compiled with appropriate 

90 flags for linking (e.g., relocatable device code enabled). 

91 """ 

92 return Linker_link(self, str(target_type)) 1OtMpzAqBrCuljDEFGHIvmwxnyJKQLksocadefghbi

93  

94 def get_error_log(self) -> str: 

95 """Get the error log generated by the linker. 

96  

97 Returns 

98 ------- 

99 str 

100 The error log. 

101 """ 

102 # After link(), the decoded log is cached here. 

103 if self._error_log is not None: 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

104 return self._error_log 1s

105 cdef cynvjitlink.nvJitLinkHandle c_h 

106 cdef size_t c_log_size = 0 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

107 cdef char* c_log_ptr 

108 if self._use_nvjitlink: 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

109 c_h = as_cu(self._nvjitlink_handle) 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

110 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetErrorLogSize(c_h, &c_log_size)) 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

111 log = bytearray(c_log_size) 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

112 if c_log_size > 0: 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

113 c_log_ptr = <char*>(<bytearray>log) 1O

114 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetErrorLog(c_h, c_log_ptr)) 1O

115 return log.decode("utf-8", errors="backslashreplace") 1NOtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

116 else: 

117 return (<bytearray>self._drv_log_bufs[2]).decode( 

118 "utf-8", errors="backslashreplace").rstrip('\x00') 

119  

120 def get_info_log(self) -> str: 

121 """Get the info log generated by the linker. 

122  

123 Returns 

124 ------- 

125 str 

126 The info log. 

127 """ 

128 # After link(), the decoded log is cached here. 

129 if self._info_log is not None: 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

130 return self._info_log 1ts

131 cdef cynvjitlink.nvJitLinkHandle c_h 

132 cdef size_t c_log_size = 0 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

133 cdef char* c_log_ptr 

134 if self._use_nvjitlink: 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

135 c_h = as_cu(self._nvjitlink_handle) 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

136 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetInfoLogSize(c_h, &c_log_size)) 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

137 log = bytearray(c_log_size) 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

138 if c_log_size > 0: 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

139 c_log_ptr = <char*>(<bytearray>log) 1pqrjkab

140 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetInfoLog(c_h, c_log_ptr)) 1pqrjkab

141 return log.decode("utf-8", errors="backslashreplace") 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

142 else: 

143 return (<bytearray>self._drv_log_bufs[0]).decode( 

144 "utf-8", errors="backslashreplace").rstrip('\x00') 

145  

146 def close(self): 

147 """Destroy this linker.""" 

148 if self._use_nvjitlink: 1Pcadefghbi

149 self._nvjitlink_handle.reset() 1Pcadefghbi

150 else: 

151 self._culink_handle.reset() 

152  

153 @property 

154 def handle(self) -> LinkerHandleT: 

155 """Return the underlying handle object. 

156  

157 .. note:: 

158  

159 The type of the returned object depends on the backend. 

160  

161 .. caution:: 

162  

163 This handle is a Python object. To get the memory address of the underlying C 

164 handle, call ``int(Linker.handle)``. 

165 """ 

166 if self._use_nvjitlink: 1SP

167 return as_py(self._nvjitlink_handle) 1SP

168 else: 

169 return as_py(self._culink_handle) 

170  

171 @classmethod 

172 def which_backend(cls) -> CompilerBackendType: 

173 """Return which linking backend will be used. 

174  

175 Returns :attr:`~CompilerBackendType.NVJITLINK` when the nvJitLink 

176 library is available and meets the minimum version requirement, 

177 otherwise :attr:`~CompilerBackendType.DRIVER`. 

178  

179 .. note:: 

180  

181 Prefer letting :class:`Linker` decide. Query ``which_backend()`` 

182 only when you need to dispatch based on input format (for 

183 example: choose PTX vs. LTOIR before constructing a 

184 ``Linker``). The returned value names an implementation 

185 detail whose support matrix may shift across CTK releases. 

186 """ 

187 return CompilerBackendType.DRIVER if _decide_nvjitlink_or_driver() else CompilerBackendType.NVJITLINK 1?Z0MpzAqBrCuljDEFGHIvmwxnyJRoPcadefghbi

188  

189  

190# ============================================================================= 

191# Supporting classes 

192# ============================================================================= 

193  

194@dataclass 

195class LinkerOptions: 

196 """Customizable options for configuring :class:`Linker`. 

197  

198 Since the linker may choose to use nvJitLink or the driver APIs as the linking backend, 

199 not all options are applicable. When the system's installed nvJitLink is too old (<12.3), 

200 or not installed, the driver APIs (cuLink) will be used instead. 

201  

202 Attributes 

203 ---------- 

204 name : str, optional 

205 Name of the linker. If the linking succeeds, the name is passed down to the generated :class:`ObjectCode`. 

206 arch : str, optional 

207 Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or 

208 ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture 

209 will be used. 

210 max_register_count : int, optional 

211 Maximum register count. 

212 time : bool, optional 

213 Print timing information to the info log. 

214 Default: False. 

215 verbose : bool, optional 

216 Print verbose messages to the info log. 

217 Default: False. 

218 link_time_optimization : bool, optional 

219 Perform link time optimization. 

220 Default: False. 

221 ptx : bool, optional 

222 Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``. 

223 Default: False. 

224 optimization_level : int, optional 

225 Set optimization level. Only 0 and 3 are accepted. 

226 debug : bool, optional 

227 Generate debug information. 

228 Default: False. 

229 lineinfo : bool, optional 

230 Generate line information. 

231 Default: False. 

232 ftz : bool, optional 

233 Flush denormal values to zero. 

234 Default: False. 

235 prec_div : bool, optional 

236 Use precise division. 

237 Default: True. 

238 prec_sqrt : bool, optional 

239 Use precise square root. 

240 Default: True. 

241 fma : bool, optional 

242 Use fast multiply-add. 

243 Default: True. 

244 kernels_used : [str | tuple[str] | list[str]], optional 

245 Pass a kernel or sequence of kernels that are used; any not in the list can be removed. 

246 variables_used : [str | tuple[str] | list[str]], optional 

247 Pass a variable or sequence of variables that are used; any not in the list can be removed. 

248 optimize_unused_variables : bool, optional 

249 Assume that if a variable is not referenced in device code, it can be removed. 

250 Default: False. 

251 ptxas_options : [str | tuple[str] | list[str]], optional 

252 Pass options to PTXAS. 

253 split_compile : int, optional 

254 Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split 

255 compilation (default). 

256 Default: 1. 

257 split_compile_extended : int, optional 

258 A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value. 

259 Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This 

260 option can potentially impact performance of the compiled binary. 

261 Default: 1. 

262 no_cache : bool, optional 

263 Do not cache the intermediate steps of nvJitLink. 

264 Default: False. 

265 """ 

266  

267 name: str | None = "<default linker>" 

268 arch: str | None = None 

269 max_register_count: int | None = None 

270 time: bool | None = None 

271 verbose: bool | None = None 

272 link_time_optimization: bool | None = None 

273 ptx: bool | None = None 

274 optimization_level: int | None = None 

275 debug: bool | None = None 

276 lineinfo: bool | None = None 

277 ftz: bool | None = None 

278 prec_div: bool | None = None 

279 prec_sqrt: bool | None = None 

280 fma: bool | None = None 

281 kernels_used: str | tuple[str] | list[str] | None = None 

282 variables_used: str | tuple[str] | list[str] | None = None 

283 optimize_unused_variables: bool | None = None 

284 ptxas_options: str | tuple[str] | list[str] | None = None 

285 split_compile: int | None = None 

286 split_compile_extended: int | None = None 

287 no_cache: bool | None = None 

288  

289 def __post_init__(self): 

290 _lazy_init() 1NOtSTKQLksXUVRoPcadefghbi

291 self._name = self.name.encode() 1NOtSTKQLksXUVRoPcadefghbi

292  

293 def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]: 

294 options = [] 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

295  

296 if self.arch is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

297 options.append(f"-arch={self.arch}") 1OtSpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

298 else: 

299 options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability)) 1M

300 if self.max_register_count is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

301 options.append(f"-maxrregcount={self.max_register_count}") 1zUc

302 if self.time is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

303 options.append("-time") 1rb

304 if self.verbose: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

305 options.append("-verbose") 1p

306 if self.link_time_optimization: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

307 options.append("-lto") 1k

308 if self.ptx: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

309 options.append("-ptx") 1Tk

310 if self.optimization_level is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

311 options.append(f"-O{self.optimization_level}") 1A

312 if self.debug: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

313 options.append("-g") 1qUVa

314 if self.lineinfo: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

315 options.append("-lineinfo") 1BVd

316 if self.ftz is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

317 options.append(f"-ftz={'true' if self.ftz else 'false'}") 1FUe

318 if self.prec_div is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

319 options.append(f"-prec-div={'true' if self.prec_div else 'false'}") 1Gf

320 if self.prec_sqrt is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

321 options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}") 1Hg

322 if self.fma is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

323 options.append(f"-fma={'true' if self.fma else 'false'}") 1Ih

324 if self.kernels_used is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

325 if isinstance(self.kernels_used, str): 1vmw

326 options.append(f"-kernels-used={self.kernels_used}") 1v

327 elif isinstance(self.kernels_used, list): 1mw

328 for kernel in self.kernels_used: 1m

329 options.append(f"-kernels-used={kernel}") 1m

330 if self.variables_used is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

331 if isinstance(self.variables_used, str): 1xny

332 options.append(f"-variables-used={self.variables_used}") 1x

333 elif isinstance(self.variables_used, list): 1ny

334 for variable in self.variables_used: 1n

335 options.append(f"-variables-used={variable}") 1n

336 if self.optimize_unused_variables is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

337 options.append("-optimize-unused-variables") 1C

338 if self.ptxas_options is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

339 if isinstance(self.ptxas_options, str): 1ulj

340 options.append(f"-Xptxas={self.ptxas_options}") 1u

341 elif is_sequence(self.ptxas_options): 1lj

342 for opt in self.ptxas_options: 1lj

343 options.append(f"-Xptxas={opt}") 1lj

344 if self.split_compile is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

345 options.append(f"-split-compile={self.split_compile}") 1Di

346 if self.split_compile_extended is not None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

347 options.append(f"-split-compile-extended={self.split_compile_extended}") 1E

348 if self.no_cache is True: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

349 options.append("-no-cache") 1J

350  

351 if as_bytes: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksUVRoPcadefghbi

352 return [o.encode() for o in options] 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksURoPcadefghbi

353 else: 

354 return options 1V

355  

356 def _prepare_driver_options(self) -> tuple[list, list]: 

357 formatted_options = [] 

358 option_keys = [] 

359  

360 # allocate a fixed-sized buffer for each info/error log 

361 size = 4194304 

362 formatted_options.extend((bytearray(size), size, bytearray(size), size)) 

363 option_keys.extend( 

364 ( 

365 _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER, 

366 _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, 

367 _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER, 

368 _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, 

369 ) 

370 ) 

371  

372 if self.arch is not None: 

373 arch = self.arch.split("_")[-1].upper() 

374 formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}")) 

375 option_keys.append(_driver.CUjit_option.CU_JIT_TARGET) 

376 if self.max_register_count is not None: 

377 formatted_options.append(self.max_register_count) 

378 option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS) 

379 if self.time is not None: 

380 raise ValueError("time option is not supported by the driver API") 

381 if self.verbose: 

382 formatted_options.append(1) 

383 option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE) 

384 if self.link_time_optimization: 

385 formatted_options.append(1) 

386 option_keys.append(_driver.CUjit_option.CU_JIT_LTO) 

387 if self.ptx: 

388 raise ValueError("ptx option is not supported by the driver API") 

389 if self.optimization_level is not None: 

390 formatted_options.append(self.optimization_level) 

391 option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL) 

392 if self.debug: 

393 formatted_options.append(1) 

394 option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO) 

395 if self.lineinfo: 

396 formatted_options.append(1) 

397 option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO) 

398 if self.ftz is not None: 

399 warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3) 

400 if self.prec_div is not None: 

401 warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3) 

402 if self.prec_sqrt is not None: 

403 warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3) 

404 if self.fma is not None: 

405 warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3) 

406 if self.kernels_used is not None: 

407 warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3) 

408 if self.variables_used is not None: 

409 warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3) 

410 if self.optimize_unused_variables is not None: 

411 warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3) 

412 if self.ptxas_options is not None: 

413 raise ValueError("ptxas_options option is not supported by the driver API") 

414 if self.split_compile is not None: 

415 raise ValueError("split_compile option is not supported by the driver API") 

416 if self.split_compile_extended is not None: 

417 raise ValueError("split_compile_extended option is not supported by the driver API") 

418 if self.no_cache is True: 

419 formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE) 

420 option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE) 

421  

422 return formatted_options, option_keys 

423  

424 def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]: 

425 """Convert linker options to bytes format for the nvjitlink backend. 

426  

427 Parameters 

428 ---------- 

429 backend : str, optional 

430 The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink". 

431  

432 Returns 

433 ------- 

434 list[bytes] 

435 List of option strings encoded as bytes. 

436  

437 Raises 

438 ------ 

439 ValueError 

440 If an unsupported backend is specified. 

441 RuntimeError 

442 If nvJitLink backend is not available. 

443 """ 

444 backend = backend.lower() 1XU

445 if backend != "nvjitlink": 1XU

446 raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'") 1X

447 if not _use_nvjitlink_backend: 1U

448 raise RuntimeError("nvJitLink backend is not available") 

449 return self._prepare_nvjitlink_options(as_bytes=True) 1U

450  

451  

452# ============================================================================= 

453# Private implementation: cdef inline helpers 

454# ============================================================================= 

455  

456cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1: 

457 """Initialize a Linker instance.""" 

458 if len(object_codes) == 0: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

459 raise ValueError("At least one ObjectCode object must be provided") 

460  

461 cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink 

462 cdef cydriver.CUlinkState c_raw_culink 

463 cdef Py_ssize_t c_num_opts, i 

464 cdef vector[const_char_ptr] c_str_opts 

465 cdef vector[cydriver.CUjit_option] c_jit_keys 

466 cdef vector[void_ptr] c_jit_values 

467  

468 self._options = options = check_or_create_options(LinkerOptions, options, "Linker options") 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

469  

470 if _use_nvjitlink_backend: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

471 self._use_nvjitlink = True 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

472 options_bytes = options._prepare_nvjitlink_options(as_bytes=True) 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

473 c_num_opts = len(options_bytes) 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

474 c_str_opts.resize(c_num_opts) 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

475 for i in range(c_num_opts): 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

476 c_str_opts[i] = <const char*>(<bytes>options_bytes[i]) 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

477 with nogil: 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

478 HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate( 1OtSMpzAqBrCuljDEFGHIvmwxnyJTKQLksRoPcadefghbi

479 &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data())) 

480 self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink) 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

481 else: 

482 self._use_nvjitlink = False 

483 formatted_options, option_keys = options._prepare_driver_options() 

484 # Keep the formatted_options list alive: it contains bytearrays that 

485 # the driver writes into via raw pointers during linking operations. 

486 self._drv_log_bufs = formatted_options 

487 c_num_opts = len(option_keys) 

488 c_jit_keys.resize(c_num_opts) 

489 c_jit_values.resize(c_num_opts) 

490 for i in range(c_num_opts): 

491 c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i] 

492 val = formatted_options[i] 

493 if isinstance(val, bytearray): 

494 c_jit_values[i] = <void*>PyByteArray_AS_STRING(val) 

495 else: 

496 c_jit_values[i] = <void*><intptr_t>int(val) 

497 try: 

498 with nogil: 

499 HANDLE_RETURN(cydriver.cuLinkCreate( 

500 <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink)) 

501 except CUDAError as e: 

502 Linker_annotate_error_log(self, e) 

503 raise 

504 self._culink_handle = create_culink_handle(c_raw_culink) 

505  

506 for code in object_codes: 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

507 assert_type(code, ObjectCode) 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

508 Linker_add_code_object(self, code) 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

509 return 0 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

510  

511  

512cdef inline void Linker_add_code_object(Linker self, object object_code) except *: 

513 """Add a single ObjectCode to the linker.""" 

514 data = object_code.code 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

515 cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h 

516 cdef cydriver.CUlinkState c_culink_state 

517 cdef cynvjitlink.nvJitLinkInputType c_nv_input_type 

518 cdef cydriver.CUjitInputType c_drv_input_type 

519 cdef const char* c_data_ptr 

520 cdef size_t c_data_size 

521 cdef const char* c_name_ptr 

522 cdef const char* c_file_ptr 

523  

524 name_bytes = f"{object_code.name}".encode() 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

525 c_name_ptr = <const char*>name_bytes 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

526  

527 input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

528 py_input_type = input_types.get(object_code.code_type) 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

529 if py_input_type is None: 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

530 raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}") 

531  

532 if self._use_nvjitlink: 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

533 c_nvjitlink_h = as_cu(self._nvjitlink_handle) 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

534 c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

535 if isinstance(data, bytes): 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

536 c_data_ptr = <const char*>(<bytes>data) 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

537 c_data_size = len(data) 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

538 with nogil: 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

539 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData( 1OtSMpzAqBrCuljDEFGHIvmwxnyJKQLksRoPcadefghbi

540 c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr)) 

541 elif isinstance(data, str): 

542 file_bytes = data.encode() 

543 c_file_ptr = <const char*>file_bytes 

544 with nogil: 

545 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile( 

546 c_nvjitlink_h, c_nv_input_type, c_file_ptr)) 

547 else: 

548 raise TypeError(f"Expected bytes or str, but got {type(data).__name__}") 

549 else: 

550 c_culink_state = as_cu(self._culink_handle) 

551 c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type 

552 try: 

553 if isinstance(data, bytes): 

554 c_data_ptr = <const char*>(<bytes>data) 

555 c_data_size = len(data) 

556 with nogil: 

557 HANDLE_RETURN(cydriver.cuLinkAddData( 

558 c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr, 

559 0, NULL, NULL)) 

560 elif isinstance(data, str): 

561 file_bytes = data.encode() 

562 c_file_ptr = <const char*>file_bytes 

563 with nogil: 

564 HANDLE_RETURN(cydriver.cuLinkAddFile( 

565 c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL)) 

566 else: 

567 raise TypeError(f"Expected bytes or str, but got {type(data).__name__}") 

568 except CUDAError as e: 

569 Linker_annotate_error_log(self, e) 

570 raise 

571  

572  

573cdef inline object Linker_link(Linker self, str target_type): 

574 """Complete linking and return the result as ObjectCode.""" 

575 if target_type not in ("cubin", "ptx"): 1OtMpzAqBrCuljDEFGHIvmwxnyJKQLksocadefghbi

576 raise ValueError(f"Unsupported target type: {target_type}") 1Q

577  

578 cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h 

579 cdef cydriver.CUlinkState c_culink_state 

580 cdef size_t c_output_size = 0 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

581 cdef char* c_code_ptr 

582 cdef void* c_cubin_out = NULL 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

583  

584 if self._use_nvjitlink: 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

585 c_nvjitlink_h = as_cu(self._nvjitlink_handle) 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

586 with nogil: 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

587 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h)) 1OtMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

588 if target_type == "cubin": 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

589 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1tMpzAqBrCuljDEFGHIvmwxnyJKLsocadefghbi

590 cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size)) 1tMpzAqBrCuljDEFGHIvmwxnyJKLsocadefghbi

591 code = bytearray(c_output_size) 1tMpzAqBrCuljDEFGHIvmwxnyJKLsocadefghbi

592 c_code_ptr = <char*>(<bytearray>code) 1tMpzAqBrCuljDEFGHIvmwxnyJKLsocadefghbi

593 with nogil: 1tMpzAqBrCuljDEFGHIvmwxnyJKLsocadefghbi

594 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1tMpzAqBrCuljDEFGHIvmwxnyJKLsocadefghbi

595 cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr)) 1tMpzAqBrCuljDEFGHIvmwxnyJKLsocadefghbi

596 else: 

597 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1k

598 cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size)) 1k

599 code = bytearray(c_output_size) 1k

600 c_code_ptr = <char*>(<bytearray>code) 1k

601 with nogil: 1k

602 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1k

603 cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr)) 1k

604 else: 

605 c_culink_state = as_cu(self._culink_handle) 

606 try: 

607 with nogil: 

608 HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size)) 

609 except CUDAError as e: 

610 Linker_annotate_error_log(self, e) 

611 raise 

612 code = (<char*>c_cubin_out)[:c_output_size] 

613  

614 # Linking is complete; cache the decoded log strings and release 

615 # the driver's raw bytearray buffers (no longer written to). 

616 self._info_log = self.get_info_log() 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

617 self._error_log = self.get_error_log() 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

618 self._drv_log_bufs = None 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

619  

620 return ObjectCode._init(bytes(code), target_type, name=self._options.name) 1tMpzAqBrCuljDEFGHIvmwxnyJKLksocadefghbi

621  

622  

623cdef inline void Linker_annotate_error_log(Linker self, object e): 

624 """Annotate a CUDAError with the driver linker error log.""" 

625 error_log = self.get_error_log() 

626 if error_log: 

627 e.args = (e.args[0] + f"\nLinker error log: {error_log}", *e.args[1:]) 

628  

629  

630# ============================================================================= 

631# Private implementation: module-level state and initialization 

632# ============================================================================= 

633  

634# TODO: revisit this treatment for py313t builds 

635_driver = None # populated if nvJitLink cannot be used 

636_inited = False 

637_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver() 

638  

639# Input type mappings populated by _lazy_init() with C-level enum ints. 

640_nvjitlink_input_types = None 

641_driver_input_types = None 

642  

643  

644def _nvjitlink_has_version_symbol(nvjitlink) -> bool: 

645 # This condition is equivalent to testing for version >= 12.3 

646 return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion")) 

647  

648  

649# Note: this function is reused in the tests 

650def _decide_nvjitlink_or_driver() -> bool: 

651 """Return True if falling back to the cuLink* driver APIs.""" 

652 global _driver, _use_nvjitlink_backend 

653 if _use_nvjitlink_backend is not None: 1NZ0MpzAqBrCuljDEFGHIvmwxnyJRWYoPcadefghbi123456789!#$%'()*+,-./:;=

654 return not _use_nvjitlink_backend 1NZ0MpzAqBrCuljDEFGHIvmwxnyJRoPcadefghbi123456789!#$%'()*+,-./:;=

655  

656 warn_txt_common = ( 

657 "the driver APIs will be used instead, which do not support" 1NWY

658 " minor version compatibility or linking LTO IRs." 

659 " For best results, consider upgrading to a recent version of" 

660 ) 

661  

662 nvjitlink_module = _optional_cuda_import( 1NWY

663 "cuda.bindings.nvjitlink", 

664 probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load 1NWY

665 ) 

666 if nvjitlink_module is None: 1NW

667 warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings." 1W

668 else: 

669 from cuda.bindings._internal import nvjitlink 

670  

671 if _nvjitlink_has_version_symbol(nvjitlink): 

672 _use_nvjitlink_backend = True 

673 return False # Use nvjitlink 

674 warn_txt = ( 

675 f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)." 

676 f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink." 

677 ) 

678  

679 warn(warn_txt, stacklevel=2, category=RuntimeWarning) 1W

680 _use_nvjitlink_backend = False 1W

681 _driver = driver 1W

682 return True 1W

683  

684  

685def _lazy_init(): 

686 global _inited, _nvjitlink_input_types, _driver_input_types 

687 if _inited: 1NOtSTKQLksXUVRoPcadefghbi

688 return 1NOtSTKQLksXUVRoPcadefghbi

689  

690 _decide_nvjitlink_or_driver() 

691 if _use_nvjitlink_backend: 

692 _nvjitlink_input_types = { 

693 "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX, 

694 "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN, 

695 "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN, 

696 "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR, 

697 "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT, 

698 "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY, 

699 } 

700 else: 

701 _driver_input_types = { 

702 "ptx": <int>cydriver.CU_JIT_INPUT_PTX, 

703 "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN, 

704 "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY, 

705 "object": <int>cydriver.CU_JIT_INPUT_OBJECT, 

706 "library": <int>cydriver.CU_JIT_INPUT_LIBRARY, 

707 } 

708 _inited = True