Coverage for cuda/core/_linker.pyx: 81.10%

365 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-13 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4"""Linking machinery for combining object codes. 

5  

6This module provides :class:`Linker` for linking one or more 

7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for 

8configuration. 

9""" 

10  

11from __future__ import annotations 

12  

13from cpython.bytearray cimport PyByteArray_AS_STRING 

14from libc.stdint cimport intptr_t, uint32_t 

15from libcpp.vector cimport vector 

16from cuda.bindings cimport cydriver 

17from cuda.bindings cimport cynvjitlink 

18  

19from ._resource_handles cimport ( 

20 as_cu, 

21 as_py, 

22 create_culink_handle, 

23 create_nvjitlink_handle, 

24) 

25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK 

26  

27import sys 

28from dataclasses import dataclass 

29from typing import TYPE_CHECKING, Union 

30from warnings import warn 

31  

32from cuda.pathfinder._optional_cuda_import import _optional_cuda_import 

33from cuda.core._device import Device 

34from cuda.core._module import ObjectCode 

35from cuda.core._utils.clear_error_support import assert_type 

36from cuda.core._utils.cuda_utils import ( 

37 CUDAError, 

38 check_or_create_options, 

39 driver, 

40 is_sequence, 

41) 

42from cuda.core.typing import CompilerBackendType, ObjectCodeFormatType 

43  

44if TYPE_CHECKING: 

45 import cuda.bindings.driver # no-cython-lint 

46 import cuda.bindings.nvjitlink # no-cython-lint 

47  

48# Module-level annotations to ensure stubgen-pyx keeps the above imports in 

49# the generated `.pyi` so that the LinkerHandleT forward references resolve. 

50# These names are not assigned, so they only affect __annotations__. 

51_keep_driver_in_stub: "cuda.bindings.driver.CUlinkState" 

52_keep_nvjitlink_in_stub: "cuda.bindings.nvjitlink.nvJitLinkHandle" 

53  

54ctypedef const char* const_char_ptr 

55ctypedef void* void_ptr 

56  

57__all__ = ["Linker", "LinkerOptions"] 

58  

59LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"] 

60  

61  

62# ============================================================================= 

63# Principal class 

64# ============================================================================= 

65  

66cdef class Linker: 

67 """Represent a linking machinery to link one or more object codes into 

68 :class:`~cuda.core.ObjectCode`. 

69  

70 This object provides a unified interface to multiple underlying 

71 linker libraries (such as nvJitLink or cuLink* from the CUDA driver). 

72  

73 Parameters 

74 ---------- 

75 object_codes : :class:`~cuda.core.ObjectCode` 

76 One or more ObjectCode objects to be linked. 

77 options : :class:`LinkerOptions`, optional 

78 Options for the linker. If not provided, default options will be used. 

79 """ 

80  

81 def __init__(self, *object_codes: ObjectCode, options: LinkerOptions | None = None): 

82 Linker_init(self, object_codes, options) 1$OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

83  

84 def link(self, target_type: ObjectCodeFormatType | str) -> ObjectCode: 

85 """Link the provided object codes into a single output of the specified target type. 

86  

87 Parameters 

88 ---------- 

89 target_type : ObjectCodeFormatType | str 

90 The type of the target output. Must be either "cubin" or "ptx". 

91  

92 Returns 

93 ------- 

94 :class:`~cuda.core.ObjectCode` 

95 The linked object code of the specified target type. 

96  

97 .. note:: 

98  

99 Ensure that input object codes were compiled with appropriate 

100 flags for linking (e.g., relocatable device code enabled). 

101 """ 

102 return Linker_link(self, str(target_type)) 1OtMpzAqBrCumkDEFGHIvnwxoyJKQLlscdefghiabj

103  

104 def get_error_log(self) -> str: 

105 """Get the error log generated by the linker. 

106  

107 Returns 

108 ------- 

109 str 

110 The error log. 

111 """ 

112 # After link(), the decoded log is cached here. 

113 if self._error_log is not None: 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

114 return self._error_log 1s

115 cdef cynvjitlink.nvJitLinkHandle c_h 

116 cdef size_t c_log_size = 0 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

117 cdef char* c_log_ptr 

118 if self._use_nvjitlink: 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

119 c_h = as_cu(self._nvjitlink_handle) 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

120 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetErrorLogSize(c_h, &c_log_size)) 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

121 log = bytearray(c_log_size) 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

122 if c_log_size > 0: 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

123 c_log_ptr = <char*>(<bytearray>log) 1O

124 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetErrorLog(c_h, c_log_ptr)) 1O

125 return log.decode("utf-8", errors="backslashreplace") 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

126 else: 

127 return (<bytearray>self._drv_log_bufs[2]).decode( 

128 "utf-8", errors="backslashreplace").rstrip('\x00') 

129  

130 def get_info_log(self) -> str: 

131 """Get the info log generated by the linker. 

132  

133 Returns 

134 ------- 

135 str 

136 The info log. 

137 """ 

138 # After link(), the decoded log is cached here. 

139 if self._info_log is not None: 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

140 return self._info_log 1ts

141 cdef cynvjitlink.nvJitLinkHandle c_h 

142 cdef size_t c_log_size = 0 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

143 cdef char* c_log_ptr 

144 if self._use_nvjitlink: 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

145 c_h = as_cu(self._nvjitlink_handle) 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

146 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetInfoLogSize(c_h, &c_log_size)) 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

147 log = bytearray(c_log_size) 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

148 if c_log_size > 0: 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

149 c_log_ptr = <char*>(<bytearray>log) 1pqrklab

150 HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetInfoLog(c_h, c_log_ptr)) 1pqrklab

151 return log.decode("utf-8", errors="backslashreplace") 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

152 else: 

153 return (<bytearray>self._drv_log_bufs[0]).decode( 

154 "utf-8", errors="backslashreplace").rstrip('\x00') 

155  

156 def close(self) -> None: 

157 """Destroy this linker.""" 

158 if self._use_nvjitlink: 1Pcdefghiabj

159 self._nvjitlink_handle.reset() 1Pcdefghiabj

160 else: 

161 self._culink_handle.reset() 

162  

163 @property 

164 def handle(self) -> LinkerHandleT: 

165 """Return the underlying handle object. 

166  

167 .. note:: 

168  

169 The type of the returned object depends on the backend. 

170  

171 .. caution:: 

172  

173 This handle is a Python object. To get the memory address of the underlying C 

174 handle, call ``int(Linker.handle)``. 

175 """ 

176 if self._use_nvjitlink: 1RP

177 return as_py(self._nvjitlink_handle) 1RP

178 else: 

179 return as_py(self._culink_handle) 

180  

181 @classmethod 

182 def which_backend(cls) -> CompilerBackendType: 

183 """Return which linking backend will be used. 

184  

185 Returns :attr:`~CompilerBackendType.NVJITLINK` when the nvJitLink 

186 library is available and meets the minimum version requirement, 

187 otherwise :attr:`~CompilerBackendType.DRIVER`. 

188  

189 .. note:: 

190  

191 Prefer letting :class:`Linker` decide. Query ``which_backend()`` 

192 only when you need to dispatch based on input format (for 

193 example: choose PTX vs. LTOIR before constructing a 

194 ``Linker``). The returned value names an implementation 

195 detail whose support matrix may shift across CTK releases. 

196 """ 

197 return CompilerBackendType.DRIVER if _decide_nvjitlink_or_driver() else CompilerBackendType.NVJITLINK 2db% ' M p z A q B r C u m k D E F G H I v n w x o y J P c d e f g h i a b j

198  

199  

200# ============================================================================= 

201# Supporting classes 

202# ============================================================================= 

203  

204@dataclass 

205class LinkerOptions: 

206 """Customizable options for configuring :class:`Linker`. 

207  

208 Since the linker may choose to use nvJitLink or the driver APIs as the linking backend, 

209 not all options are applicable. When the system's installed nvJitLink is too old (<12.3), 

210 or not installed, the driver APIs (cuLink) will be used instead. 

211  

212 Attributes 

213 ---------- 

214 name : str, optional 

215 Name of the linker. If the linking succeeds, the name is passed down to the generated :class:`ObjectCode`. 

216 arch : str, optional 

217 Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or 

218 ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture 

219 will be used. 

220 max_register_count : int, optional 

221 Maximum register count. 

222 time : bool, optional 

223 Print timing information to the info log. 

224 Default: False. 

225 verbose : bool, optional 

226 Print verbose messages to the info log. 

227 Default: False. 

228 link_time_optimization : bool, optional 

229 Perform link time optimization. 

230 Default: False. 

231 ptx : bool, optional 

232 Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``. 

233 Default: False. 

234 optimization_level : int, optional 

235 Set optimization level. Only 0 and 3 are accepted. 

236 debug : bool, optional 

237 Generate debug information. 

238 Default: False. 

239 lineinfo : bool, optional 

240 Generate line information. 

241 Default: False. 

242 ftz : bool, optional 

243 Flush denormal values to zero. 

244 Default: False. 

245 prec_div : bool, optional 

246 Use precise division. 

247 Default: True. 

248 prec_sqrt : bool, optional 

249 Use precise square root. 

250 Default: True. 

251 fma : bool, optional 

252 Use fast multiply-add. 

253 Default: True. 

254 kernels_used : [str | tuple[str] | list[str]], optional 

255 Pass a kernel or sequence of kernels that are used; any not in the list can be removed. 

256 variables_used : [str | tuple[str] | list[str]], optional 

257 Pass a variable or sequence of variables that are used; any not in the list can be removed. 

258 optimize_unused_variables : bool, optional 

259 Assume that if a variable is not referenced in device code, it can be removed. 

260 Default: False. 

261 ptxas_options : [str | tuple[str] | list[str]], optional 

262 Pass options to PTXAS. 

263 split_compile : int, optional 

264 Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split 

265 compilation (default). 

266 Default: 1. 

267 split_compile_extended : int, optional 

268 A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value. 

269 Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This 

270 option can potentially impact performance of the compiled binary. 

271 Default: 1. 

272 no_cache : bool, optional 

273 Do not cache the intermediate steps of nvJitLink. 

274 Default: False. 

275 """ 

276  

277 name: str | None = "<default linker>" 

278 arch: str | None = None 

279 max_register_count: int | None = None 

280 time: bool | None = None 

281 verbose: bool | None = None 

282 link_time_optimization: bool | None = None 

283 ptx: bool | None = None 

284 optimization_level: int | None = None 

285 debug: bool | None = None 

286 lineinfo: bool | None = None 

287 ftz: bool | None = None 

288 prec_div: bool | None = None 

289 prec_sqrt: bool | None = None 

290 fma: bool | None = None 

291 kernels_used: str | tuple[str] | list[str] | None = None 

292 variables_used: str | tuple[str] | list[str] | None = None 

293 optimize_unused_variables: bool | None = None 

294 ptxas_options: str | tuple[str] | list[str] | None = None 

295 split_compile: int | None = None 

296 split_compile_extended: int | None = None 

297 no_cache: bool | None = None 

298 numba_debug: bool | None = None 

299  

300 def __post_init__(self) -> None: 

301 _lazy_init() 1N9OtRTKQLls!U5SVWXYZ0176432Pcdefghiabj

302 self._name = self.name.encode() 1N9OtRTKQLls!U5SVWXYZ0176432Pcdefghiabj

303  

304 def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]: 

305 options = [] 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

306  

307 if self.arch is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

308 options.append(f"-arch={self.arch}") 1OtRpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

309 else: 

310 options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability)) 1M

311 if self.max_register_count is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

312 options.append(f"-maxrregcount={self.max_register_count}") 1zUc

313 if self.time is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

314 options.append("-time") 1rb

315 if self.verbose: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

316 options.append("-verbose") 1p

317 if self.link_time_optimization: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

318 options.append("-lto") 1l

319 if self.ptx: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

320 options.append("-ptx") 1Tl

321 if self.optimization_level is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

322 options.append(f"-O{self.optimization_level}") 1A

323 if self.debug: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

324 options.append("-g") 1qU5da

325 if self.lineinfo: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

326 options.append("-lineinfo") 1B5e

327 if self.ftz is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

328 options.append(f"-ftz={'true' if self.ftz else 'false'}") 1FUf

329 if self.prec_div is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

330 options.append(f"-prec-div={'true' if self.prec_div else 'false'}") 1Gg

331 if self.prec_sqrt is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

332 options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}") 1Hh

333 if self.fma is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

334 options.append(f"-fma={'true' if self.fma else 'false'}") 1Ii

335 if self.kernels_used is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

336 if isinstance(self.kernels_used, str): 1vnw

337 options.append(f"-kernels-used={self.kernels_used}") 1v

338 elif isinstance(self.kernels_used, list): 1nw

339 for kernel in self.kernels_used: 1n

340 options.append(f"-kernels-used={kernel}") 1n

341 if self.variables_used is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

342 if isinstance(self.variables_used, str): 1xoy

343 options.append(f"-variables-used={self.variables_used}") 1x

344 elif isinstance(self.variables_used, list): 1oy

345 for variable in self.variables_used: 1o

346 options.append(f"-variables-used={variable}") 1o

347 if self.optimize_unused_variables is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

348 options.append("-optimize-unused-variables") 1C

349 if self.ptxas_options is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

350 if isinstance(self.ptxas_options, str): 1umk

351 options.append(f"-Xptxas={self.ptxas_options}") 1u

352 elif is_sequence(self.ptxas_options): 1mk

353 for opt in self.ptxas_options: 1mk

354 options.append(f"-Xptxas={opt}") 1mk

355 if self.split_compile is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

356 options.append(f"-split-compile={self.split_compile}") 1Dj

357 if self.split_compile_extended is not None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

358 options.append(f"-split-compile-extended={self.split_compile_extended}") 1E

359 if self.no_cache is True: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

360 options.append("-no-cache") 1J

361  

362 if as_bytes: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsU5Pcdefghiabj

363 return [o.encode() for o in options] 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsUPcdefghiabj

364 else: 

365 return options 15

366  

367 def _prepare_driver_options(self) -> tuple[list[object], list[object]]: 

368 formatted_options = [] 1SVWXYZ0176432

369 option_keys = [] 1SVWXYZ0176432

370  

371 # allocate a fixed-sized buffer for each info/error log 

372 size = 4194304 1SVWXYZ0176432

373 formatted_options.extend((bytearray(size), size, bytearray(size), size)) 1SVWXYZ0176432

374 option_keys.extend( 1SVWXYZ0176432

375 ( 

376 _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER, 1SVWXYZ0176432

377 _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, 1SVWXYZ0176432

378 _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER, 1SVWXYZ0176432

379 _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, 1SVWXYZ0176432

380 ) 

381 ) 

382  

383 if self.arch is not None: 1SVWXYZ0176432

384 arch = self.arch.split("_")[-1].upper() 1S

385 formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}")) 1S

386 option_keys.append(_driver.CUjit_option.CU_JIT_TARGET) 1S

387 if self.max_register_count is not None: 1SVWXYZ0176432

388 formatted_options.append(self.max_register_count) 1S

389 option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS) 1S

390 if self.time is not None: 1SVWXYZ0176432

391 raise ValueError("time option is not supported by the driver API") 17

392 if self.verbose: 1SVWXYZ016432

393 formatted_options.append(1) 1S

394 option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE) 1S

395 if self.link_time_optimization: 1SVWXYZ016432

396 formatted_options.append(1) 1S

397 option_keys.append(_driver.CUjit_option.CU_JIT_LTO) 1S

398 if self.ptx: 1SVWXYZ016432

399 raise ValueError("ptx option is not supported by the driver API") 16

400 if self.optimization_level is not None: 1SVWXYZ01432

401 formatted_options.append(self.optimization_level) 1S

402 option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL) 1S

403 if self.debug: 1SVWXYZ01432

404 formatted_options.append(1) 1S

405 option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO) 1S

406 if self.lineinfo: 1SVWXYZ01432

407 formatted_options.append(1) 1S

408 option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO) 1S

409 if self.ftz is not None: 1SVWXYZ01432

410 warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3) 1V

411 if self.prec_div is not None: 1SVWXYZ01432

412 warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3) 1W

413 if self.prec_sqrt is not None: 1SVWXYZ01432

414 warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3) 1X

415 if self.fma is not None: 1SVWXYZ01432

416 warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3) 1Y

417 if self.kernels_used is not None: 1SVWXYZ01432

418 warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3) 1Z

419 if self.variables_used is not None: 1SVWXYZ01432

420 warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3) 10

421 if self.optimize_unused_variables is not None: 1SVWXYZ01432

422 warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3) 11

423 if self.ptxas_options is not None: 1SVWXYZ01432

424 raise ValueError("ptxas_options option is not supported by the driver API") 14

425 if self.split_compile is not None: 1SVWXYZ0132

426 raise ValueError("split_compile option is not supported by the driver API") 13

427 if self.split_compile_extended is not None: 1SVWXYZ012

428 raise ValueError("split_compile_extended option is not supported by the driver API") 12

429 if self.no_cache is True: 1SVWXYZ01

430 formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE) 1S

431 option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE) 1S

432  

433 return formatted_options, option_keys 1SVWXYZ01

434  

435 def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]: 

436 """Convert linker options to bytes format for the nvjitlink backend. 

437  

438 Parameters 

439 ---------- 

440 backend : str, optional 

441 The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink". 

442  

443 Returns 

444 ------- 

445 list[bytes] 

446 List of option strings encoded as bytes. 

447  

448 Raises 

449 ------ 

450 ValueError 

451 If an unsupported backend is specified. 

452 RuntimeError 

453 If nvJitLink backend is not available. 

454 """ 

455 backend = backend.lower() 19!U

456 if backend != "nvjitlink": 19!U

457 raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'") 1!

458 if not _use_nvjitlink_backend: 19U

459 raise RuntimeError("nvJitLink backend is not available") 19

460 return self._prepare_nvjitlink_options(as_bytes=True) 1U

461  

462  

463# ============================================================================= 

464# Private implementation: cdef inline helpers 

465# ============================================================================= 

466  

467cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1: 

468 """Initialize a Linker instance.""" 

469 if len(object_codes) == 0: 1$OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

470 raise ValueError("At least one ObjectCode object must be provided") 1$

471  

472 cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink 

473 cdef cydriver.CUlinkState c_raw_culink 

474 cdef Py_ssize_t c_num_opts, i 

475 cdef vector[const_char_ptr] c_str_opts 

476 cdef vector[cydriver.CUjit_option] c_jit_keys 

477 cdef vector[void_ptr] c_jit_values 

478  

479 self._options = options = check_or_create_options(LinkerOptions, options, "Linker options") 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

480  

481 if _use_nvjitlink_backend: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

482 self._use_nvjitlink = True 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

483 options_bytes = options._prepare_nvjitlink_options(as_bytes=True) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

484 c_num_opts = len(options_bytes) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

485 c_str_opts.resize(c_num_opts) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

486 for i in range(c_num_opts): 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

487 c_str_opts[i] = <const char*>(<bytes>options_bytes[i]) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

488 with nogil: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

489 HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate( 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj

490 &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data())) 

491 self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

492 else: 

493 self._use_nvjitlink = False 

494 formatted_options, option_keys = options._prepare_driver_options() 

495 # Keep the formatted_options list alive: it contains bytearrays that 

496 # the driver writes into via raw pointers during linking operations. 

497 self._drv_log_bufs = formatted_options 

498 c_num_opts = len(option_keys) 

499 c_jit_keys.resize(c_num_opts) 

500 c_jit_values.resize(c_num_opts) 

501 for i in range(c_num_opts): 

502 c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i] 

503 val = formatted_options[i] 

504 if isinstance(val, bytearray): 

505 c_jit_values[i] = <void*>PyByteArray_AS_STRING(val) 

506 else: 

507 c_jit_values[i] = <void*><intptr_t>int(val) 

508 try: 

509 with nogil: 

510 HANDLE_RETURN(cydriver.cuLinkCreate( 

511 <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink)) 

512 except CUDAError as e: 

513 Linker_annotate_error_log(self, e) 

514 raise 

515 self._culink_handle = create_culink_handle(c_raw_culink) 

516  

517 for code in object_codes: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

518 assert_type(code, ObjectCode) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

519 Linker_add_code_object(self, code) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

520 return 0 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

521  

522  

523cdef inline void Linker_add_code_object(Linker self, object object_code) except *: 

524 """Add a single ObjectCode to the linker.""" 

525 data = object_code.code 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

526 cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h 

527 cdef cydriver.CUlinkState c_culink_state 

528 cdef cynvjitlink.nvJitLinkInputType c_nv_input_type 

529 cdef cydriver.CUjitInputType c_drv_input_type 

530 cdef const char* c_data_ptr 

531 cdef size_t c_data_size 

532 cdef const char* c_name_ptr 

533 cdef const char* c_file_ptr 

534  

535 name_bytes = f"{object_code.name}".encode() 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

536 c_name_ptr = <const char*>name_bytes 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

537  

538 input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

539 py_input_type = input_types.get(object_code.code_type) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

540 if py_input_type is None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

541 raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}") 

542  

543 if self._use_nvjitlink: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

544 c_nvjitlink_h = as_cu(self._nvjitlink_handle) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

545 c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

546 if isinstance(data, bytes): 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

547 c_data_ptr = <const char*>(<bytes>data) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

548 c_data_size = len(data) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

549 with nogil: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

550 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData( 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj

551 c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr)) 

552 elif isinstance(data, str): 

553 file_bytes = data.encode() 

554 c_file_ptr = <const char*>file_bytes 

555 with nogil: 

556 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile( 

557 c_nvjitlink_h, c_nv_input_type, c_file_ptr)) 

558 else: 

559 raise TypeError(f"Expected bytes or str, but got {type(data).__name__}") 

560 else: 

561 c_culink_state = as_cu(self._culink_handle) 

562 c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type 

563 try: 

564 if isinstance(data, bytes): 

565 c_data_ptr = <const char*>(<bytes>data) 

566 c_data_size = len(data) 

567 with nogil: 

568 HANDLE_RETURN(cydriver.cuLinkAddData( 

569 c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr, 

570 0, NULL, NULL)) 

571 elif isinstance(data, str): 

572 file_bytes = data.encode() 

573 c_file_ptr = <const char*>file_bytes 

574 with nogil: 

575 HANDLE_RETURN(cydriver.cuLinkAddFile( 

576 c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL)) 

577 else: 

578 raise TypeError(f"Expected bytes or str, but got {type(data).__name__}") 

579 except CUDAError as e: 

580 Linker_annotate_error_log(self, e) 

581 raise 

582  

583  

584cdef inline object Linker_link(Linker self, str target_type): 

585 """Complete linking and return the result as ObjectCode.""" 

586 if target_type not in ("cubin", "ptx"): 1OtMpzAqBrCumkDEFGHIvnwxoyJKQLlscdefghiabj

587 raise ValueError(f"Unsupported target type: {target_type}") 1Q

588  

589 cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h 

590 cdef cydriver.CUlinkState c_culink_state 

591 cdef size_t c_output_size = 0 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

592 cdef char* c_code_ptr 

593 cdef void* c_cubin_out = NULL 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

594  

595 if self._use_nvjitlink: 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

596 c_nvjitlink_h = as_cu(self._nvjitlink_handle) 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

597 with nogil: 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

598 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h)) 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

599 if target_type == "cubin": 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

600 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj

601 cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size)) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj

602 code = bytearray(c_output_size) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj

603 c_code_ptr = <char*>(<bytearray>code) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj

604 with nogil: 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj

605 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj

606 cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr)) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj

607 else: 

608 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1l

609 cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size)) 1l

610 code = bytearray(c_output_size) 1l

611 c_code_ptr = <char*>(<bytearray>code) 1l

612 with nogil: 1l

613 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1l

614 cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr)) 1l

615 else: 

616 c_culink_state = as_cu(self._culink_handle) 

617 try: 

618 with nogil: 

619 HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size)) 

620 except CUDAError as e: 

621 Linker_annotate_error_log(self, e) 

622 raise 

623 code = (<char*>c_cubin_out)[:c_output_size] 

624  

625 # Linking is complete; cache the decoded log strings and release 

626 # the driver's raw bytearray buffers (no longer written to). 

627 self._info_log = self.get_info_log() 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

628 self._error_log = self.get_error_log() 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

629 self._drv_log_bufs = None 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

630  

631 return ObjectCode._init(bytes(code), target_type, name=self._options.name) 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj

632  

633  

634cdef inline void Linker_annotate_error_log(Linker self, object e): 

635 """Annotate a CUDAError with the driver linker error log.""" 

636 error_log = self.get_error_log() 

637 if error_log: 

638 e.args = (e.args[0] + f"\nLinker error log: {error_log}", *e.args[1:]) 

639  

640  

641# ============================================================================= 

642# Private implementation: module-level state and initialization 

643# ============================================================================= 

644  

645# TODO: revisit this treatment for py313t builds 

646_driver = None # populated if nvJitLink cannot be used 

647_inited = False 

648_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver() 

649  

650# Input type mappings populated by _lazy_init() with C-level enum ints. 

651_nvjitlink_input_types = None 

652_driver_input_types = None 

653  

654  

655def _nvjitlink_has_version_symbol(nvjitlink) -> bool: 

656 # This condition is equivalent to testing for version >= 12.3 

657 return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion")) 

658  

659  

660# Note: this function is reused in the tests 

661def _decide_nvjitlink_or_driver() -> bool: 

662 """Return True if falling back to the cuLink* driver APIs.""" 

663 global _driver, _use_nvjitlink_backend 

664 if _use_nvjitlink_backend is not None: 2N % ' M p z A q B r C u m k D E F G H I v n w x o y J 8 # P c d e f g h i a b j ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbcb

665 return not _use_nvjitlink_backend 2N % ' M p z A q B r C u m k D E F G H I v n w x o y J P c d e f g h i a b j ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbcb

666  

667 warn_txt_common = ( 

668 "the driver APIs will be used instead, which do not support" 1N8#

669 " minor version compatibility or linking LTO IRs." 

670 " For best results, consider upgrading to a recent version of" 

671 ) 

672  

673 nvjitlink_module = _optional_cuda_import( 1N8#

674 "cuda.bindings.nvjitlink", 

675 probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load 1N8#

676 ) 

677 if nvjitlink_module is None: 1N8

678 warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings." 18

679 else: 

680 from cuda.bindings._internal import nvjitlink 

681  

682 if _nvjitlink_has_version_symbol(nvjitlink): 

683 _use_nvjitlink_backend = True 

684 return False # Use nvjitlink 

685 warn_txt = ( 

686 f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)." 

687 f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink." 

688 ) 

689  

690 warn(warn_txt, stacklevel=2, category=RuntimeWarning) 18

691 _use_nvjitlink_backend = False 18

692 _driver = driver 18

693 return True 18

694  

695  

696def _lazy_init() -> None: 

697 global _inited, _nvjitlink_input_types, _driver_input_types 

698 if _inited: 1N9OtRTKQLls!U5SVWXYZ0176432Pcdefghiabj

699 return 1N9OtRTKQLls!U5SVWXYZ0176432Pcdefghiabj

700  

701 _decide_nvjitlink_or_driver() 

702 if _use_nvjitlink_backend: 

703 _nvjitlink_input_types = { 

704 "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX, 

705 "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN, 

706 "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN, 

707 "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR, 

708 "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT, 

709 "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY, 

710 } 

711 else: 

712 _driver_input_types = { 

713 "ptx": <int>cydriver.CU_JIT_INPUT_PTX, 

714 "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN, 

715 "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY, 

716 "object": <int>cydriver.CU_JIT_INPUT_OBJECT, 

717 "library": <int>cydriver.CU_JIT_INPUT_LIBRARY, 

718 } 

719 _inited = True