Coverage for cuda / core / experimental / _program.py: 85%

364 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-10 01:19 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7import weakref 

8from contextlib import contextmanager 

9from dataclasses import dataclass 

10from typing import TYPE_CHECKING, Union 

11from warnings import warn 

12 

13if TYPE_CHECKING: 

14 import cuda.bindings 

15 

16from cuda.core.experimental._device import Device 

17from cuda.core.experimental._linker import Linker, LinkerHandleT, LinkerOptions 

18from cuda.core.experimental._module import ObjectCode 

19from cuda.core.experimental._utils.clear_error_support import assert_type 

20from cuda.core.experimental._utils.cuda_utils import ( 

21 _handle_boolean_option, 

22 check_or_create_options, 

23 driver, 

24 get_binding_version, 

25 handle_return, 

26 is_nested_sequence, 

27 is_sequence, 

28 nvrtc, 

29) 

30 

31 

32@contextmanager 

33def _nvvm_exception_manager(self): 

34 """ 

35 Taken from _linker.py 

36 """ 

37 try: 

38 yield 

39 except Exception as e: 

40 error_log = "" 

41 if hasattr(self, "_mnff"): 

42 try: 

43 nvvm = _get_nvvm_module() 

44 logsize = nvvm.get_program_log_size(self._mnff.handle) 

45 if logsize > 1: 

46 log = bytearray(logsize) 

47 nvvm.get_program_log(self._mnff.handle, log) 

48 error_log = log.decode("utf-8", errors="backslashreplace") 

49 except Exception: 

50 error_log = "" 

51 # Starting Python 3.11 we could also use Exception.add_note() for the same purpose, but 

52 # unfortunately we are still supporting Python 3.10... 

53 e.args = (e.args[0] + (f"\nNVVM program log: {error_log}" if error_log else ""), *e.args[1:]) 

54 raise e 

55 

56 

# Lazily-imported cuda.bindings.nvvm module. Stays None until _get_nvvm_module()
# succeeds, or permanently None if the import attempt failed.
_nvvm_module = None
# Set True after the first _get_nvvm_module() call so a failed import is not
# retried (and re-diagnosed) on every subsequent call.
_nvvm_import_attempted = False

59 

60 

def _get_nvvm_module():
    """
    Handles the import of NVVM module with version and availability checks.
    NVVM bindings were added in cuda-bindings 12.9.0, so we need to handle cases where:
    1. cuda.bindings is not new enough (< 12.9.0)
    2. libnvvm is not found in the Python environment

    Returns:
        The nvvm module if available and working

    Raises:
        RuntimeError: If NVVM is not available due to version or library issues
    """
    global _nvvm_module, _nvvm_import_attempted

    # Fast path: reuse (or re-report) the outcome of the first attempt.
    if _nvvm_import_attempted:
        if _nvvm_module is None:
            raise RuntimeError("NVVM module is not available (previous import attempt failed)")
        return _nvvm_module

    _nvvm_import_attempted = True

    try:
        ver = get_binding_version()
        if ver < (12, 9):
            msg = (
                f"NVVM bindings require cuda-bindings >= 12.9.0, but found {ver[0]}.{ver[1]}.x. "
                "Please update cuda-bindings to use NVVM features."
            )
            raise RuntimeError(msg)

        from cuda.bindings import nvvm
        from cuda.bindings._internal.nvvm import _inspect_function_pointer

        # A null function pointer means the bindings are present but libnvvm itself
        # could not be located at runtime.
        if _inspect_function_pointer("__nvvmCreateProgram") == 0:
            raise RuntimeError("NVVM library (libnvvm) is not available in this Python environment. ")

        _nvvm_module = nvvm
        return _nvvm_module
    except RuntimeError:
        _nvvm_module = None
        raise

103 

104 

105def _process_define_macro_inner(formatted_options, macro): 

106 if isinstance(macro, str): 

107 formatted_options.append(f"--define-macro={macro}") 

108 return True 

109 if isinstance(macro, tuple): 

110 if len(macro) != 2 or any(not isinstance(val, str) for val in macro): 

111 raise RuntimeError(f"Expected define_macro tuple[str, str], got {macro}") 

112 formatted_options.append(f"--define-macro={macro[0]}={macro[1]}") 

113 return True 

114 return False 

115 

116 

def _process_define_macro(formatted_options, macro):
    """Expand *macro* (str, 2-tuple of str, or a sequence of those) into
    ``--define-macro`` flags appended to *formatted_options*.

    Raises RuntimeError when *macro*, or any element of a sequence, is not of
    an accepted form.
    """
    union_type = "Union[str, tuple[str, str]]"
    # Single str / tuple form handled directly.
    if _process_define_macro_inner(formatted_options, macro):
        return
    if not is_nested_sequence(macro):
        raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}")
    # Sequence form: every element must itself be a valid single macro.
    for seq_macro in macro:
        if not _process_define_macro_inner(formatted_options, seq_macro):
            raise RuntimeError(f"Expected define_macro {union_type}, got {seq_macro}")

127 

128 

129@dataclass 

130class ProgramOptions: 

131 """Customizable options for configuring `Program`. 

132 

133 Attributes 

134 ---------- 

135 name : str, optional 

136 Name of the program. If the compilation succeeds, the name is passed down to the generated `ObjectCode`. 

137 arch : str, optional 

138 Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or 

139 ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture 

140 will be used. 

141 relocatable_device_code : bool, optional 

142 Enable (disable) the generation of relocatable device code. 

143 Default: False 

144 extensible_whole_program : bool, optional 

145 Do extensible whole program compilation of device code. 

146 Default: False 

147 debug : bool, optional 

148 Generate debug information. If --dopt is not specified, then turns off all optimizations. 

149 Default: False 

150 lineinfo: bool, optional 

151 Generate line-number information. 

152 Default: False 

153 device_code_optimize : bool, optional 

154 Enable device code optimization. When specified along with ‘-G’, enables limited debug information generation 

155 for optimized device code. 

156 Default: None 

157 ptxas_options : Union[str, list[str]], optional 

158 Specify one or more options directly to ptxas, the PTX optimizing assembler. Options should be strings. 

159 For example ["-v", "-O2"]. 

160 Default: None 

161 max_register_count : int, optional 

162 Specify the maximum amount of registers that GPU functions can use. 

163 Default: None 

164 ftz : bool, optional 

165 When performing single-precision floating-point operations, flush denormal values to zero or preserve denormal 

166 values. 

167 Default: False 

168 prec_sqrt : bool, optional 

169 For single-precision floating-point square root, use IEEE round-to-nearest mode or use a faster approximation. 

170 Default: True 

171 prec_div : bool, optional 

172 For single-precision floating-point division and reciprocals, use IEEE round-to-nearest mode or use a faster 

173 approximation. 

174 Default: True 

175 fma : bool, optional 

176 Enables (disables) the contraction of floating-point multiplies and adds/subtracts into floating-point 

177 multiply-add operations. 

178 Default: True 

179 use_fast_math : bool, optional 

180 Make use of fast math operations. 

181 Default: False 

182 extra_device_vectorization : bool, optional 

183 Enables more aggressive device code vectorization in the NVVM optimizer. 

184 Default: False 

185 link_time_optimization : bool, optional 

186 Generate intermediate code for later link-time optimization. 

187 Default: False 

188 gen_opt_lto : bool, optional 

189 Run the optimizer passes before generating the LTO IR. 

190 Default: False 

191 define_macro : Union[str, tuple[str, str], list[Union[str, tuple[str, str]]]], optional 

192 Predefine a macro. Can be either a string, in which case that macro will be set to 1, a 2 element tuple of 

193 strings, in which case the first element is defined as the second, or a list of strings or tuples. 

194 Default: None 

195 undefine_macro : Union[str, list[str]], optional 

196 Cancel any previous definition of a macro, or list of macros. 

197 Default: None 

198 include_path : Union[str, list[str]], optional 

199 Add the directory or directories to the list of directories to be searched for headers. 

200 Default: None 

201 pre_include : Union[str, list[str]], optional 

202 Preinclude one or more headers during preprocessing. Can be either a string or a list of strings. 

203 Default: None 

204 no_source_include : bool, optional 

205 Disable the default behavior of adding the directory of each input source to the include path. 

206 Default: False 

207 std : str, optional 

208 Set language dialect to C++03, C++11, C++14, C++17 or C++20. 

209 Default: c++17 

210 builtin_move_forward : bool, optional 

211 Provide builtin definitions of std::move and std::forward. 

212 Default: True 

213 builtin_initializer_list : bool, optional 

214 Provide builtin definitions of std::initializer_list class and member functions. 

215 Default: True 

216 disable_warnings : bool, optional 

217 Inhibit all warning messages. 

218 Default: False 

219 restrict : bool, optional 

220 Programmer assertion that all kernel pointer parameters are restrict pointers. 

221 Default: False 

222 device_as_default_execution_space : bool, optional 

223 Treat entities with no execution space annotation as __device__ entities. 

224 Default: False 

225 device_int128 : bool, optional 

226 Allow the __int128 type in device code. 

227 Default: False 

228 optimization_info : str, optional 

229 Provide optimization reports for the specified kind of optimization. 

230 Default: None 

231 no_display_error_number : bool, optional 

232 Disable the display of a diagnostic number for warning messages. 

233 Default: False 

234 diag_error : Union[int, list[int]], optional 

235 Emit error for a specified diagnostic message number or comma separated list of numbers. 

236 Default: None 

237 diag_suppress : Union[int, list[int]], optional 

238 Suppress a specified diagnostic message number or comma separated list of numbers. 

239 Default: None 

240 diag_warn : Union[int, list[int]], optional 

241 Emit warning for a specified diagnostic message number or comma separated lis of numbers. 

242 Default: None 

243 brief_diagnostics : bool, optional 

244 Disable or enable showing source line and column info in a diagnostic. 

245 Default: False 

246 time : str, optional 

247 Generate a CSV table with the time taken by each compilation phase. 

248 Default: None 

249 split_compile : int, optional 

250 Perform compiler optimizations in parallel. 

251 Default: 1 

252 fdevice_syntax_only : bool, optional 

253 Ends device compilation after front-end syntax checking. 

254 Default: False 

255 minimal : bool, optional 

256 Omit certain language features to reduce compile time for small programs. 

257 Default: False 

258 """ 

259 

260 name: str | None = "<default program>" 

261 arch: str | None = None 

262 relocatable_device_code: bool | None = None 

263 extensible_whole_program: bool | None = None 

264 debug: bool | None = None 

265 lineinfo: bool | None = None 

266 device_code_optimize: bool | None = None 

267 ptxas_options: Union[str, list[str], tuple[str]] | None = None 

268 max_register_count: int | None = None 

269 ftz: bool | None = None 

270 prec_sqrt: bool | None = None 

271 prec_div: bool | None = None 

272 fma: bool | None = None 

273 use_fast_math: bool | None = None 

274 extra_device_vectorization: bool | None = None 

275 link_time_optimization: bool | None = None 

276 gen_opt_lto: bool | None = None 

277 define_macro: ( 

278 Union[str, tuple[str, str], list[Union[str, tuple[str, str]]], tuple[Union[str, tuple[str, str]]]] | None 

279 ) = None 

280 undefine_macro: Union[str, list[str], tuple[str]] | None = None 

281 include_path: Union[str, list[str], tuple[str]] | None = None 

282 pre_include: Union[str, list[str], tuple[str]] | None = None 

283 no_source_include: bool | None = None 

284 std: str | None = None 

285 builtin_move_forward: bool | None = None 

286 builtin_initializer_list: bool | None = None 

287 disable_warnings: bool | None = None 

288 restrict: bool | None = None 

289 device_as_default_execution_space: bool | None = None 

290 device_int128: bool | None = None 

291 optimization_info: str | None = None 

292 no_display_error_number: bool | None = None 

293 diag_error: Union[int, list[int], tuple[int]] | None = None 

294 diag_suppress: Union[int, list[int], tuple[int]] | None = None 

295 diag_warn: Union[int, list[int], tuple[int]] | None = None 

296 brief_diagnostics: bool | None = None 

297 time: str | None = None 

298 split_compile: int | None = None 

299 fdevice_syntax_only: bool | None = None 

300 minimal: bool | None = None 

301 numba_debug: bool | None = None # Custom option for Numba debugging 

302 

303 def __post_init__(self): 

304 self._name = self.name.encode() 

305 

306 self._formatted_options = [] 

307 if self.arch is not None: 

308 self._formatted_options.append(f"-arch={self.arch}") 

309 else: 

310 self.arch = f"sm_{Device().arch}" 

311 self._formatted_options.append(f"-arch={self.arch}") 

312 if self.relocatable_device_code is not None: 

313 self._formatted_options.append( 

314 f"--relocatable-device-code={_handle_boolean_option(self.relocatable_device_code)}" 

315 ) 

316 if self.extensible_whole_program is not None and self.extensible_whole_program: 

317 self._formatted_options.append("--extensible-whole-program") 

318 if self.debug is not None and self.debug: 

319 self._formatted_options.append("--device-debug") 

320 if self.lineinfo is not None and self.lineinfo: 

321 self._formatted_options.append("--generate-line-info") 

322 if self.device_code_optimize is not None and self.device_code_optimize: 

323 self._formatted_options.append("--dopt=on") 

324 if self.ptxas_options is not None: 

325 opt_name = "--ptxas-options" 

326 if isinstance(self.ptxas_options, str): 

327 self._formatted_options.append(f"{opt_name}={self.ptxas_options}") 

328 elif is_sequence(self.ptxas_options): 

329 for opt_value in self.ptxas_options: 

330 self._formatted_options.append(f"{opt_name}={opt_value}") 

331 if self.max_register_count is not None: 

332 self._formatted_options.append(f"--maxrregcount={self.max_register_count}") 

333 if self.ftz is not None: 

334 self._formatted_options.append(f"--ftz={_handle_boolean_option(self.ftz)}") 

335 if self.prec_sqrt is not None: 

336 self._formatted_options.append(f"--prec-sqrt={_handle_boolean_option(self.prec_sqrt)}") 

337 if self.prec_div is not None: 

338 self._formatted_options.append(f"--prec-div={_handle_boolean_option(self.prec_div)}") 

339 if self.fma is not None: 

340 self._formatted_options.append(f"--fmad={_handle_boolean_option(self.fma)}") 

341 if self.use_fast_math is not None and self.use_fast_math: 

342 self._formatted_options.append("--use_fast_math") 

343 if self.extra_device_vectorization is not None and self.extra_device_vectorization: 

344 self._formatted_options.append("--extra-device-vectorization") 

345 if self.link_time_optimization is not None and self.link_time_optimization: 

346 self._formatted_options.append("--dlink-time-opt") 

347 if self.gen_opt_lto is not None and self.gen_opt_lto: 

348 self._formatted_options.append("--gen-opt-lto") 

349 if self.define_macro is not None: 

350 _process_define_macro(self._formatted_options, self.define_macro) 

351 if self.undefine_macro is not None: 

352 if isinstance(self.undefine_macro, str): 

353 self._formatted_options.append(f"--undefine-macro={self.undefine_macro}") 

354 elif is_sequence(self.undefine_macro): 

355 for macro in self.undefine_macro: 

356 self._formatted_options.append(f"--undefine-macro={macro}") 

357 if self.include_path is not None: 

358 if isinstance(self.include_path, str): 

359 self._formatted_options.append(f"--include-path={self.include_path}") 

360 elif is_sequence(self.include_path): 

361 for path in self.include_path: 

362 self._formatted_options.append(f"--include-path={path}") 

363 if self.pre_include is not None: 

364 if isinstance(self.pre_include, str): 

365 self._formatted_options.append(f"--pre-include={self.pre_include}") 

366 elif is_sequence(self.pre_include): 

367 for header in self.pre_include: 

368 self._formatted_options.append(f"--pre-include={header}") 

369 

370 if self.no_source_include is not None and self.no_source_include: 

371 self._formatted_options.append("--no-source-include") 

372 if self.std is not None: 

373 self._formatted_options.append(f"--std={self.std}") 

374 if self.builtin_move_forward is not None: 

375 self._formatted_options.append( 

376 f"--builtin-move-forward={_handle_boolean_option(self.builtin_move_forward)}" 

377 ) 

378 if self.builtin_initializer_list is not None: 

379 self._formatted_options.append( 

380 f"--builtin-initializer-list={_handle_boolean_option(self.builtin_initializer_list)}" 

381 ) 

382 if self.disable_warnings is not None and self.disable_warnings: 

383 self._formatted_options.append("--disable-warnings") 

384 if self.restrict is not None and self.restrict: 

385 self._formatted_options.append("--restrict") 

386 if self.device_as_default_execution_space is not None and self.device_as_default_execution_space: 

387 self._formatted_options.append("--device-as-default-execution-space") 

388 if self.device_int128 is not None and self.device_int128: 

389 self._formatted_options.append("--device-int128") 

390 if self.optimization_info is not None: 

391 self._formatted_options.append(f"--optimization-info={self.optimization_info}") 

392 if self.no_display_error_number is not None and self.no_display_error_number: 

393 self._formatted_options.append("--no-display-error-number") 

394 if self.diag_error is not None: 

395 if isinstance(self.diag_error, int): 

396 self._formatted_options.append(f"--diag-error={self.diag_error}") 

397 elif is_sequence(self.diag_error): 

398 for error in self.diag_error: 

399 self._formatted_options.append(f"--diag-error={error}") 

400 if self.diag_suppress is not None: 

401 if isinstance(self.diag_suppress, int): 

402 self._formatted_options.append(f"--diag-suppress={self.diag_suppress}") 

403 elif is_sequence(self.diag_suppress): 

404 for suppress in self.diag_suppress: 

405 self._formatted_options.append(f"--diag-suppress={suppress}") 

406 if self.diag_warn is not None: 

407 if isinstance(self.diag_warn, int): 

408 self._formatted_options.append(f"--diag-warn={self.diag_warn}") 

409 elif is_sequence(self.diag_warn): 

410 for warn in self.diag_warn: 

411 self._formatted_options.append(f"--diag-warn={warn}") 

412 if self.brief_diagnostics is not None: 

413 self._formatted_options.append(f"--brief-diagnostics={_handle_boolean_option(self.brief_diagnostics)}") 

414 if self.time is not None: 

415 self._formatted_options.append(f"--time={self.time}") 

416 if self.split_compile is not None: 

417 self._formatted_options.append(f"--split-compile={self.split_compile}") 

418 if self.fdevice_syntax_only is not None and self.fdevice_syntax_only: 

419 self._formatted_options.append("--fdevice-syntax-only") 

420 if self.minimal is not None and self.minimal: 

421 self._formatted_options.append("--minimal") 

422 if self.numba_debug: 

423 self._formatted_options.append("--numba-debug") 

424 

425 def _as_bytes(self): 

426 # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved 

427 return list(o.encode() for o in self._formatted_options) 

428 

429 def __repr__(self): 

430 # __TODO__ improve this 

431 return str(self._formatted_options) 

432 

433 

# Union of every handle type a Program may hold, depending on its backend:
# an NVRTC program handle, or one of the Linker handle types.
ProgramHandleT = Union["cuda.bindings.nvrtc.nvrtcProgram", LinkerHandleT]

435 

436 

class Program:
    """Represent a compilation machinery to process programs into
    :obj:`~_module.ObjectCode`.

    This object provides a unified interface to multiple underlying
    compiler libraries. Compilation support is enabled for a wide
    range of code types and compilation types.

    Parameters
    ----------
    code : Any
        String of the CUDA Runtime Compilation program.
    code_type : Any
        String of the code type. Currently ``"ptx"``, ``"c++"``, and ``"nvvm"`` are supported.
    options : ProgramOptions, optional
        A ProgramOptions object to customize the compilation process.
        See :obj:`ProgramOptions` for more information.
    """

    class _MembersNeededForFinalize:
        # Minimal state needed by the weakref finalizer so the underlying
        # program handle is destroyed even if Program.close() is never called.
        __slots__ = "handle", "backend"

        def __init__(self, program_obj, handle, backend):
            self.handle = handle
            self.backend = backend
            # Register close() to run when the owning Program is garbage-collected.
            weakref.finalize(program_obj, self.close)

        def close(self):
            # Destroy the handle via the API that created it; idempotent because
            # the handle is set to None afterwards.
            if self.handle is not None:
                if self.backend == "NVRTC":
                    handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
                elif self.backend == "NVVM":
                    nvvm = _get_nvvm_module()
                    nvvm.destroy_program(self.handle)
                self.handle = None

    __slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options")

    def __init__(self, code, code_type, options: ProgramOptions = None):
        # Create the finalizer bundle first so cleanup is registered even if
        # construction fails partway through.
        self._mnff = Program._MembersNeededForFinalize(self, None, None)

        self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
        code_type = code_type.lower()

        if code_type == "c++":
            # C++ source is compiled through NVRTC.
            assert_type(code, str)
            # TODO: support pre-loaded headers & include names
            # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved

            self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], []))
            self._mnff.backend = "NVRTC"
            self._backend = "NVRTC"
            self._linker = None

        elif code_type == "ptx":
            # PTX is not compiled here; it is handed to a Linker, whose backend
            # (nvJitLink or the driver) does the work at compile() time.
            assert_type(code, str)
            self._linker = Linker(
                ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
            )
            self._backend = self._linker.backend

        elif code_type == "nvvm":
            # NVVM IR may arrive as text or as an already-encoded bitstream.
            if isinstance(code, str):
                code = code.encode("utf-8")
            elif not isinstance(code, (bytes, bytearray)):
                raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray")

            nvvm = _get_nvvm_module()
            self._mnff.handle = nvvm.create_program()
            self._mnff.backend = "NVVM"
            nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode())
            self._backend = "NVVM"
            self._linker = None

        else:
            # Sanity check: this branch must only be reachable for unsupported types.
            supported_code_types = ("c++", "ptx", "nvvm")
            assert code_type not in supported_code_types, f"{code_type=}"
            raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})")

    def _translate_program_options(self, options: ProgramOptions) -> LinkerOptions:
        """Map the subset of ProgramOptions that the Linker understands onto LinkerOptions."""
        return LinkerOptions(
            name=options.name,
            arch=options.arch,
            max_register_count=options.max_register_count,
            time=options.time,
            debug=options.debug,
            lineinfo=options.lineinfo,
            ftz=options.ftz,
            prec_div=options.prec_div,
            prec_sqrt=options.prec_sqrt,
            fma=options.fma,
            link_time_optimization=options.link_time_optimization,
            split_compile=options.split_compile,
            ptxas_options=options.ptxas_options,
        )

    def _translate_program_options_to_nvvm(self, options: ProgramOptions) -> list[str]:
        """Translate ProgramOptions to NVVM-specific compilation options."""
        nvvm_options = []

        # arch is always populated by ProgramOptions.__post_init__.
        assert options.arch is not None
        arch = options.arch
        # NVVM expects "compute_<CC>" rather than "sm_<CC>".
        if arch.startswith("sm_"):
            arch = f"compute_{arch[3:]}"
        nvvm_options.append(f"-arch={arch}")
        if options.debug:
            nvvm_options.append("-g")
        if options.device_code_optimize is False:
            nvvm_options.append("-opt=0")
        elif options.device_code_optimize is True:
            nvvm_options.append("-opt=3")
        # NVVM is not consistent with NVRTC, it uses 0/1 instead...
        if options.ftz is not None:
            nvvm_options.append(f"-ftz={'1' if options.ftz else '0'}")
        if options.prec_sqrt is not None:
            nvvm_options.append(f"-prec-sqrt={'1' if options.prec_sqrt else '0'}")
        if options.prec_div is not None:
            nvvm_options.append(f"-prec-div={'1' if options.prec_div else '0'}")
        if options.fma is not None:
            nvvm_options.append(f"-fma={'1' if options.fma else '0'}")

        return nvvm_options

    def close(self):
        """Destroy this program."""
        # Close the linker first (PTX path), then the NVRTC/NVVM handle if any.
        if self._linker:
            self._linker.close()
        self._mnff.close()

    @staticmethod
    def _can_load_generated_ptx():
        # PTX generated by a newer NVRTC than the installed driver may not load;
        # compare NVRTC's (major, minor) against the driver version number.
        driver_ver = handle_return(driver.cuDriverGetVersion())
        nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
        return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver

    def compile(self, target_type, name_expressions=(), logs=None):
        """Compile the program with a specific compilation type.

        Parameters
        ----------
        target_type : Any
            String of the targeted compilation type.
            Supported options are "ptx", "cubin" and "ltoir".
        name_expressions : Union[list, tuple], optional
            List of explicit name expressions to become accessible.
            (Default to no expressions)
        logs : Any, optional
            Object with a write method to receive the logs generated
            from compilation.
            (Default to no logs)

        Returns
        -------
        :obj:`~_module.ObjectCode`
            Newly created code object.

        """
        supported_target_types = ("ptx", "cubin", "ltoir")
        if target_type not in supported_target_types:
            raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})')

        if self._backend == "NVRTC":
            if target_type == "ptx" and not self._can_load_generated_ptx():
                warn(
                    "The CUDA driver version is older than the backend version. "
                    "The generated ptx will not be loadable by the current driver.",
                    stacklevel=1,
                    category=RuntimeWarning,
                )
            # Name expressions must be registered before nvrtcCompileProgram runs.
            if name_expressions:
                for n in name_expressions:
                    handle_return(
                        nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()),
                        handle=self._mnff.handle,
                    )
            options = self._options._as_bytes()
            handle_return(
                nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options),
                handle=self._mnff.handle,
            )

            # Dispatch to nvrtcGetPTX/nvrtcGetCUBIN/nvrtcGetLTOIR (+ size variants)
            # based on the requested target type.
            size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
            comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
            size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle)
            data = b" " * size
            handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle)

            # Map each registered name expression to its mangled (lowered) name.
            symbol_mapping = {}
            if name_expressions:
                for n in name_expressions:
                    symbol_mapping[n] = handle_return(
                        nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle
                    )

            if logs is not None:
                logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle)
                # A size of 1 is just the NUL terminator, i.e. an empty log.
                if logsize > 1:
                    log = b" " * logsize
                    handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle)
                    logs.write(log.decode("utf-8", errors="backslashreplace"))

            return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name)

        elif self._backend == "NVVM":
            if target_type not in ("ptx", "ltoir"):
                raise ValueError(f'NVVM backend only supports target_type="ptx", "ltoir", got "{target_type}"')

            nvvm_options = self._translate_program_options_to_nvvm(self._options)
            if target_type == "ltoir" and "-gen-lto" not in nvvm_options:
                nvvm_options.append("-gen-lto")
            nvvm = _get_nvvm_module()
            # The exception manager appends the NVVM program log to any failure.
            with _nvvm_exception_manager(self):
                nvvm.verify_program(self._mnff.handle, len(nvvm_options), nvvm_options)
                nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options)

            size = nvvm.get_compiled_result_size(self._mnff.handle)
            data = bytearray(size)
            nvvm.get_compiled_result(self._mnff.handle, data)

            if logs is not None:
                logsize = nvvm.get_program_log_size(self._mnff.handle)
                # A size of 1 is just the NUL terminator, i.e. an empty log.
                if logsize > 1:
                    log = bytearray(logsize)
                    nvvm.get_program_log(self._mnff.handle, log)
                    logs.write(log.decode("utf-8", errors="backslashreplace"))

            return ObjectCode._init(data, target_type, name=self._options.name)

        # PTX path: delegate the actual work to the Linker backend.
        supported_backends = ("nvJitLink", "driver")
        if self._backend not in supported_backends:
            raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})')
        return self._linker.link(target_type)

    @property
    def backend(self) -> str:
        """Return this Program instance's underlying backend."""
        return self._backend

    @property
    def handle(self) -> ProgramHandleT:
        """Return the underlying handle object.

        .. note::

            The type of the returned object depends on the backend.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Program.handle)``.
        """
        return self._mnff.handle