Coverage for cuda / core / experimental / _linker.py: 70%

292 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-10 01:19 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7import ctypes 

8import sys 

9import weakref 

10from contextlib import contextmanager 

11from dataclasses import dataclass 

12from typing import TYPE_CHECKING, Union 

13from warnings import warn 

14 

15if TYPE_CHECKING: 

16 import cuda.bindings 

17 

18from cuda.core.experimental._device import Device 

19from cuda.core.experimental._module import ObjectCode 

20from cuda.core.experimental._utils.clear_error_support import assert_type 

21from cuda.core.experimental._utils.cuda_utils import check_or_create_options, driver, handle_return, is_sequence 

22 

# TODO: revisit this treatment for py313t builds
# Module-level backend state. It is filled in by _decide_nvjitlink_or_driver() and _lazy_init().
_driver = None  # populated if nvJitLink cannot be used
_driver_input_types = None  # populated if nvJitLink cannot be used
_driver_ver = None  # 2-tuple derived from cuDriverGetVersion(), set by _decide_nvjitlink_or_driver()
_inited = False  # flipped to True at the end of the first _lazy_init() call
_nvjitlink = None  # populated if nvJitLink can be used
_nvjitlink_input_types = None  # populated if nvJitLink can be used

30 

31 

32def _nvjitlink_has_version_symbol(inner_nvjitlink) -> bool: 

33 # This condition is equivalent to testing for version >= 12.3 

34 return bool(inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion")) 

35 

36 

# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
    """Select the linking backend and report whether the driver fallback is used.

    The decision is cached in the module globals ``_driver`` / ``_nvjitlink``; once
    either is set, later calls return immediately. When nvJitLink cannot be used, a
    ``RuntimeWarning`` is emitted and ``_driver`` is bound to the driver bindings.

    Returns
    -------
    bool
        True if falling back to the cuLink* driver APIs, False if nvJitLink is used.
    """
    global _driver_ver, _driver, _nvjitlink
    if _driver or _nvjitlink:
        # A backend has already been chosen by a previous call.
        return _driver is not None

    raw_version = handle_return(driver.cuDriverGetVersion())
    _driver_ver = (raw_version // 1000, (raw_version % 1000) // 10)

    fallback_notice = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    try:
        # Because of the ``global`` declaration above, this binds the module-level name.
        import cuda.bindings.nvjitlink as _nvjitlink
    except ModuleNotFoundError:
        message = f"cuda.bindings.nvjitlink is not available, therefore {fallback_notice} cuda-bindings."
    else:
        from cuda.bindings._internal import nvjitlink as inner_nvjitlink

        try:
            symbol_found = _nvjitlink_has_version_symbol(inner_nvjitlink)
        except RuntimeError:
            problem = "not available"
        else:
            if symbol_found:
                return False  # Use nvjitlink
            problem = "too old (<12.3)"
        library_name = "nvJitLink*.dll" if sys.platform == "win32" else "libnvJitLink.so*"
        message = (
            f"{library_name} is {problem}."
            f" Therefore cuda.bindings.nvjitlink is not usable and {fallback_notice} nvJitLink."
        )
        # The bindings imported but the library is unusable: undo the global binding.
        _nvjitlink = None

    warn(message, stacklevel=2, category=RuntimeWarning)
    _driver = driver
    return True

76 

77 

def _lazy_init():
    """Run the one-time backend selection and build the code-type lookup table.

    After the backend is decided, exactly one of ``_nvjitlink_input_types`` or
    ``_driver_input_types`` is populated with a mapping from code-type strings to
    the backend's input-type enum members. Subsequent calls are no-ops.
    """
    global _inited, _nvjitlink_input_types, _driver_input_types
    if _inited:
        return

    _decide_nvjitlink_or_driver()
    if _nvjitlink:
        if _driver_ver > _nvjitlink.version():
            # TODO: nvJitLink is not new enough, warn?
            pass
        input_kinds = _nvjitlink.InputType
        # Each nvJitLink enum member is named after the upper-cased code type.
        _nvjitlink_input_types = {
            code_type: getattr(input_kinds, code_type.upper())
            for code_type in ("ptx", "cubin", "fatbin", "ltoir", "object", "library")
        }
    else:
        input_kinds = _driver.CUjitInputType
        # No "ltoir" entry is registered for the driver backend.
        _driver_input_types = {
            code_type: getattr(input_kinds, f"CU_JIT_INPUT_{suffix}")
            for code_type, suffix in (
                ("ptx", "PTX"),
                ("cubin", "CUBIN"),
                ("fatbin", "FATBINARY"),
                ("object", "OBJECT"),
                ("library", "LIBRARY"),
            )
        }
    _inited = True

105 

106 

@dataclass
class LinkerOptions:
    """Customizable :obj:`Linker` options.

    Since the linker would choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    Boolean switches that have no "off" spelling (``time``, ``verbose``, ``link_time_optimization``,
    ``ptx``, ``debug``, ``lineinfo``, ``optimize_unused_variables``) only take effect when truthy;
    passing ``False`` is the same as leaving them unset.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated `ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [Union[str, tuple[str], list[str]]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [Union[str, tuple[str], list[str]]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [Union[str, tuple[str], list[str]]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: Union[str, tuple[str], list[str]] | None = None
    variables_used: Union[str, tuple[str], list[str]] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: Union[str, tuple[str], list[str]] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        """Translate the fields into backend-specific options for whichever backend is active."""
        _lazy_init()
        self._name = self.name.encode()
        self.formatted_options = []
        if _nvjitlink:
            self._init_nvjitlink()
        else:
            self._init_driver()

    def _append_repeated_option(self, flag, values):
        """Append ``flag=value`` once for a single string, or once per element of a sequence."""
        if isinstance(values, str):
            self.formatted_options.append(f"{flag}={values}")
        elif is_sequence(values):
            # Lists and tuples are both documented as accepted, so do not special-case list.
            self.formatted_options.extend(f"{flag}={value}" for value in values)

    def _init_nvjitlink(self):
        """Fill ``formatted_options`` with nvJitLink command-line style option strings."""
        if self.arch is not None:
            self.formatted_options.append(f"-arch={self.arch}")
        else:
            # No architecture requested: target the current device's compute capability.
            self.formatted_options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            self.formatted_options.append(f"-maxrregcount={self.max_register_count}")
        # "-time" is a bare switch, so an explicit False must not emit it.
        if self.time:
            self.formatted_options.append("-time")
        if self.verbose:
            self.formatted_options.append("-verbose")
        if self.link_time_optimization:
            self.formatted_options.append("-lto")
        if self.ptx:
            self.formatted_options.append("-ptx")
        if self.optimization_level is not None:
            self.formatted_options.append(f"-O{self.optimization_level}")
        if self.debug:
            self.formatted_options.append("-g")
        if self.lineinfo:
            self.formatted_options.append("-lineinfo")
        # The options below carry an explicit true/false value, so False is meaningful.
        if self.ftz is not None:
            self.formatted_options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            self.formatted_options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            self.formatted_options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
        if self.kernels_used is not None:
            self._append_repeated_option("-kernels-used", self.kernels_used)
        if self.variables_used is not None:
            self._append_repeated_option("-variables-used", self.variables_used)
        # Bare switch as well: only emit when actually requested.
        if self.optimize_unused_variables:
            self.formatted_options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            self._append_repeated_option("-Xptxas", self.ptxas_options)
        if self.split_compile is not None:
            self.formatted_options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            self.formatted_options.append("-no-cache")

    def _init_driver(self):
        """Fill the parallel ``option_keys`` / ``formatted_options`` lists for the driver backend.

        Raises
        ------
        ValueError
            If an option that the driver API cannot express is requested.
        """
        self.option_keys = []
        # allocate 4 MiB each for info/error logs
        size = 4194304
        # Indices 0 and 2 of formatted_options hold the info and error log buffers respectively.
        self.formatted_options.extend((bytearray(size), size, bytearray(size), size))
        self.option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_90" / "compute_90" -> "90" -> CU_TARGET_COMPUTE_90
            arch = self.arch.split("_")[-1].upper()
            self.formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            self.option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            self.formatted_options.append(self.max_register_count)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        # Only reject timing when it is actually requested; time=False asks for nothing.
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            self.formatted_options.append(self.optimization_level)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        # The following options are not forwarded to the driver; the caller is only warned.
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            self.formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

330 

331 

332# This needs to be a free function not a method, as it's disallowed by contextmanager. 

333@contextmanager 

334def _exception_manager(self): 

335 """ 

336 A helper function to improve the error message of exceptions raised by the linker backend. 

337 """ 

338 try: 

339 yield 

340 except Exception as e: 

341 error_log = "" 

342 if hasattr(self, "_mnff"): 

343 # our constructor could raise, in which case there's no handle available 

344 error_log = self.get_error_log() 

345 # Starting Python 3.11 we could also use Exception.add_note() for the same purpose, but 

346 # unfortunately we are still supporting Python 3.10... 

347 # Here we rely on both CUDAError and nvJitLinkError have the error string placed in .args[0]. 

348 e.args = (e.args[0] + (f"\nLinker error log: {error_log}" if error_log else ""), *e.args[1:]) 

349 raise e 

350 

351 

# Alias used to annotate nvJitLink handles; it is simply ``int``.
nvJitLinkHandleT = int
# A linker handle is either an nvJitLink handle or a driver ``CUlinkState`` (written as a
# forward-reference string, since ``cuda.bindings`` is only imported under TYPE_CHECKING).
LinkerHandleT = Union[nvJitLinkHandleT, "cuda.bindings.driver.CUlinkState"]

354 

355 

class Linker:
    """Represent a linking machinery to link one or multiple object codes into
    :obj:`~cuda.core.experimental._module.ObjectCode` with the specified options.

    This object provides a unified interface to multiple underlying
    linker libraries (such as nvJitLink or cuLink* from CUDA driver).

    Parameters
    ----------
    object_codes : ObjectCode
        One or more ObjectCode objects to be linked.
    options : LinkerOptions, optional
        Options for the linker. If not provided, default options will be used.
    """

    class _MembersNeededForFinalize:
        """Hold the state needed to destroy the backend handle when the linker goes away."""

        __slots__ = ("handle", "use_nvjitlink", "const_char_keep_alive")

        def __init__(self, program_obj, handle, use_nvjitlink):
            self.handle = handle
            self.use_nvjitlink = use_nvjitlink
            # Byte strings handed to the driver are stored here so they outlive the call.
            self.const_char_keep_alive = []
            # Destroy the handle when program_obj is garbage collected.
            weakref.finalize(program_obj, self.close)

        def close(self):
            """Destroy the backend handle if it is still alive; safe to call repeatedly."""
            if self.handle is not None:
                if self.use_nvjitlink:
                    _nvjitlink.destroy(self.handle)
                else:
                    handle_return(_driver.cuLinkDestroy(self.handle))
                self.handle = None

    __slots__ = ("__weakref__", "_mnff", "_options")

    def __init__(self, *object_codes: ObjectCode, options: LinkerOptions | None = None):
        if len(object_codes) == 0:
            raise ValueError("At least one ObjectCode object must be provided")

        self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
        with _exception_manager(self):
            if _nvjitlink:
                handle = _nvjitlink.create(len(options.formatted_options), options.formatted_options)
                use_nvjitlink = True
            else:
                handle = handle_return(
                    _driver.cuLinkCreate(len(options.formatted_options), options.option_keys, options.formatted_options)
                )
                use_nvjitlink = False
        self._mnff = Linker._MembersNeededForFinalize(self, handle, use_nvjitlink)

        for code in object_codes:
            assert_type(code, ObjectCode)
            self._add_code_object(code)

    def _add_code_object(self, object_code: ObjectCode):
        """Feed one ObjectCode to the backend, as in-memory data (bytes) or as a file path (str).

        Raises
        ------
        TypeError
            If the object code's payload is neither ``bytes`` nor ``str``.
        """
        data = object_code._module
        with _exception_manager(self):
            name_str = f"{object_code.name}"
            if _nvjitlink and isinstance(data, bytes):
                _nvjitlink.add_data(
                    self._mnff.handle,
                    self._input_type_from_code_type(object_code._code_type),
                    data,
                    len(data),
                    name_str,
                )
            elif _nvjitlink and isinstance(data, str):
                _nvjitlink.add_file(
                    self._mnff.handle,
                    self._input_type_from_code_type(object_code._code_type),
                    data,
                )
            elif (not _nvjitlink) and isinstance(data, bytes):
                name_bytes = name_str.encode()
                handle_return(
                    _driver.cuLinkAddData(
                        self._mnff.handle,
                        self._input_type_from_code_type(object_code._code_type),
                        data,
                        len(data),
                        name_bytes,
                        0,
                        None,
                        None,
                    )
                )
                self._mnff.const_char_keep_alive.append(name_bytes)
            elif (not _nvjitlink) and isinstance(data, str):
                name_bytes = name_str.encode()
                handle_return(
                    _driver.cuLinkAddFile(
                        self._mnff.handle,
                        self._input_type_from_code_type(object_code._code_type),
                        data.encode(),
                        0,
                        None,
                        None,
                    )
                )
                self._mnff.const_char_keep_alive.append(name_bytes)
            else:
                raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")

    def link(self, target_type) -> ObjectCode:
        """
        Links the provided object codes into a single output of the specified target type.

        Parameters
        ----------
        target_type : str
            The type of the target output. Must be either "cubin" or "ptx".

        Returns
        -------
        ObjectCode
            The linked object code of the specified target type.

        Raises
        ------
        ValueError
            If ``target_type`` is neither "cubin" nor "ptx".

        Note
        ------
        See nvrtc compiler options documentation to ensure the input object codes are
        correctly compiled for linking.
        """
        if target_type not in ("cubin", "ptx"):
            raise ValueError(f"Unsupported target type: {target_type}")
        with _exception_manager(self):
            if _nvjitlink:
                _nvjitlink.complete(self._mnff.handle)
                if target_type == "cubin":
                    get_size = _nvjitlink.get_linked_cubin_size
                    get_code = _nvjitlink.get_linked_cubin
                else:
                    get_size = _nvjitlink.get_linked_ptx_size
                    get_code = _nvjitlink.get_linked_ptx
                size = get_size(self._mnff.handle)
                code = bytearray(size)
                get_code(self._mnff.handle, code)
            else:
                addr, size = handle_return(_driver.cuLinkComplete(self._mnff.handle))
                # View the driver's output buffer in place; bytes(code) below copies it out.
                code = (ctypes.c_char * size).from_address(addr)

        return ObjectCode._init(bytes(code), target_type, name=self._options.name)

    @staticmethod
    def _decode_log(log) -> str:
        """Decode a backend log buffer, stopping at the first NUL byte.

        The driver backend writes its logs into fixed-size, zero-initialised buffers, so
        decoding the whole buffer would return megabytes of ``\\x00`` padding.
        """
        text, _, _ = bytes(log).partition(b"\x00")
        return text.decode("utf-8", errors="backslashreplace")

    def get_error_log(self) -> str:
        """Get the error log generated by the linker.

        Returns
        -------
        str
            The error log.
        """
        if _nvjitlink:
            log_size = _nvjitlink.get_error_log_size(self._mnff.handle)
            log = bytearray(log_size)
            _nvjitlink.get_error_log(self._mnff.handle, log)
        else:
            # Index 2 is the error-log buffer set up by LinkerOptions for the driver backend.
            log = self._options.formatted_options[2]
        return self._decode_log(log)

    def get_info_log(self) -> str:
        """Get the info log generated by the linker.

        Returns
        -------
        str
            The info log.
        """
        if _nvjitlink:
            log_size = _nvjitlink.get_info_log_size(self._mnff.handle)
            log = bytearray(log_size)
            _nvjitlink.get_info_log(self._mnff.handle, log)
        else:
            # Index 0 is the info-log buffer set up by LinkerOptions for the driver backend.
            log = self._options.formatted_options[0]
        return self._decode_log(log)

    def _input_type_from_code_type(self, code_type: str):
        """Map an ObjectCode ``code_type`` string to the active backend's input-type enum.

        Raises
        ------
        ValueError
            If the active backend has no entry for ``code_type``.
        """
        # this list is based on the supported values for code_type in the ObjectCode class definition.
        # nvJitLink/driver support other options for input type
        input_type = _nvjitlink_input_types.get(code_type) if _nvjitlink else _driver_input_types.get(code_type)

        if input_type is None:
            raise ValueError(f"Unknown code_type associated with ObjectCode: {code_type}")
        return input_type

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

           The type of the returned object depends on the backend.

        .. caution::

           This handle is a Python object. To get the memory address of the underlying C
           handle, call ``int(Linker.handle)``.
        """
        return self._mnff.handle

    @property
    def backend(self) -> str:
        """Return this Linker instance's underlying backend."""
        return "nvJitLink" if self._mnff.use_nvjitlink else "driver"

    def close(self):
        """Destroy this linker."""
        self._mnff.close()