Coverage for cuda / core / _linker.pyx: 64.44%

360 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-29 01:27 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4"""Linking machinery for combining object codes. 

5  

6This module provides :class:`Linker` for linking one or more 

7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for 

8configuration. 

9""" 

10  

11from __future__ import annotations 

12  

13from cpython.bytearray cimport PyByteArray_AS_STRING 

14from libc.stdint cimport intptr_t, uint32_t 

15from libcpp.vector cimport vector 

16from cuda.bindings cimport cydriver 

17from cuda.bindings cimport cynvjitlink 

18  

19from ._resource_handles cimport ( 

20 as_cu, 

21 as_py, 

22 create_culink_handle, 

23 create_nvjitlink_handle, 

24) 

25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK 

26  

27import sys 

28from dataclasses import dataclass 

29from typing import Union 

30from warnings import warn 

31  

32from cuda.pathfinder._optional_cuda_import import _optional_cuda_import 

33from cuda.core._device import Device 

34from cuda.core._module import ObjectCode 

35from cuda.core._utils.clear_error_support import assert_type 

36from cuda.core._utils.cuda_utils import ( 

37 CUDAError, 

38 check_or_create_options, 

39 driver, 

40 is_sequence, 

41) 

42  

# Named aliases for the pointer types; they are used as C++ vector element
# types (vector[const_char_ptr], vector[void_ptr]) when building option arrays.
ctypedef const char* const_char_ptr
ctypedef void* void_ptr

# Names exported as this module's public API.
__all__ = ["Linker", "LinkerOptions"]

# Annotation for the object returned by ``Linker.handle``; which of the two
# alternatives applies depends on the backend in use.
LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"]

49  

50  

51# ============================================================================= 

52# Principal class 

53# ============================================================================= 

54  

cdef class Linker:
    """Link one or more object codes into a single :class:`~cuda.core.ObjectCode`.

    One front end is exposed regardless of which linker library performs the
    work underneath (nvJitLink, or the cuLink* APIs of the CUDA driver).

    Parameters
    ----------
    object_codes : :class:`~cuda.core.ObjectCode`
        The ObjectCode objects to link; at least one must be given.
    options : :class:`LinkerOptions`, optional
        Linker configuration. Default options are used when omitted.
    """

    def __init__(self, *object_codes: ObjectCode, options: "LinkerOptions" = None):
        Linker_init(self, object_codes, options)

    def link(self, target_type) -> ObjectCode:
        """Finish linking and produce one output of the requested type.

        Parameters
        ----------
        target_type : str
            Kind of output to produce; either "cubin" or "ptx".

        Returns
        -------
        :class:`~cuda.core.ObjectCode`
            The linked result, tagged with ``target_type``.

        .. note::

            The inputs must have been compiled with flags that make them
            linkable (e.g., relocatable device code enabled).
        """
        return Linker_link(self, target_type)

    def get_error_log(self) -> str:
        """Return the linker's error log.

        Returns
        -------
        str
            The error log text.
        """
        cdef cynvjitlink.nvJitLinkHandle c_h
        cdef size_t c_log_size = 0
        cdef char* c_log_ptr

        # link() stores the decoded text; prefer it once available.
        if self._error_log is not None:
            return self._error_log

        if not self._use_nvjitlink:
            # Driver backend: slot 2 of the option values is the error-log buffer.
            raw = <bytearray>self._drv_log_bufs[2]
            return raw.decode("utf-8", errors="backslashreplace").rstrip('\x00')

        c_h = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetErrorLogSize(c_h, &c_log_size)
        log = bytearray(c_log_size)
        if c_log_size > 0:
            c_log_ptr = <char*>(<bytearray>log)
            cynvjitlink.nvJitLinkGetErrorLog(c_h, c_log_ptr)
        return log.decode("utf-8", errors="backslashreplace")

    def get_info_log(self) -> str:
        """Return the linker's info log.

        Returns
        -------
        str
            The info log text.
        """
        cdef cynvjitlink.nvJitLinkHandle c_h
        cdef size_t c_log_size = 0
        cdef char* c_log_ptr

        # link() stores the decoded text; prefer it once available.
        if self._info_log is not None:
            return self._info_log

        if not self._use_nvjitlink:
            # Driver backend: slot 0 of the option values is the info-log buffer.
            raw = <bytearray>self._drv_log_bufs[0]
            return raw.decode("utf-8", errors="backslashreplace").rstrip('\x00')

        c_h = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetInfoLogSize(c_h, &c_log_size)
        log = bytearray(c_log_size)
        if c_log_size > 0:
            c_log_ptr = <char*>(<bytearray>log)
            cynvjitlink.nvJitLinkGetInfoLog(c_h, c_log_ptr)
        return log.decode("utf-8", errors="backslashreplace")

    def close(self):
        """Destroy this linker."""
        if not self._use_nvjitlink:
            self._culink_handle.reset()
            return
        self._nvjitlink_handle.reset()

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

            Which type comes back is determined by the backend.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Linker.handle)``.
        """
        if not self._use_nvjitlink:
            return as_py(self._culink_handle)
        return as_py(self._nvjitlink_handle)

    @property
    def backend(self) -> str:
        """Return the name of the backend this Linker instance uses."""
        if self._use_nvjitlink:
            return "nvJitLink"
        return "driver"

174  

175  

176# ============================================================================= 

177# Supporting classes 

178# ============================================================================= 

179  

@dataclass
class LinkerOptions:
    """Customizable options for configuring :class:`Linker`.

    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated `ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [str | tuple[str] | list[str]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [str | tuple[str] | list[str]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [str | tuple[str] | list[str]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: str | tuple[str] | list[str] | None = None
    variables_used: str | tuple[str] | list[str] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: str | tuple[str] | list[str] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        # Make sure the backend has been chosen and the input-type tables
        # exist before any option translation happens.
        _lazy_init()
        # ``name`` is annotated as optional, so tolerate None instead of
        # failing on ``None.encode()``.
        self._name = None if self.name is None else self.name.encode()

    def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]:
        """Translate the configured options into nvJitLink command-line flags.

        Parameters
        ----------
        as_bytes : bool, optional
            When True each flag is returned encoded as ``bytes``; otherwise
            the flags are returned as ``str``. Default is False.

        Returns
        -------
        list[bytes] | list[str]
            The flags, in a fixed order, starting with ``-arch=...``.
        """
        options = []

        if self.arch is not None:
            options.append(f"-arch={self.arch}")
        else:
            # No explicit architecture: derive sm_<CC> from the current device.
            options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            options.append(f"-maxrregcount={self.max_register_count}")
        # Plain on/off switches are emitted only when truthy, so an explicit
        # ``False`` behaves the same as leaving the option unset.
        if self.time:
            options.append("-time")
        if self.verbose:
            options.append("-verbose")
        if self.link_time_optimization:
            options.append("-lto")
        if self.ptx:
            options.append("-ptx")
        if self.optimization_level is not None:
            options.append(f"-O{self.optimization_level}")
        if self.debug:
            options.append("-g")
        if self.lineinfo:
            options.append("-lineinfo")
        # Valued booleans: both True and False are meaningful, so these are
        # emitted whenever the caller set them at all.
        if self.ftz is not None:
            options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            options.append(f"-fma={'true' if self.fma else 'false'}")
        # A single name or a list/tuple of names; one flag is emitted per name.
        # Tuples are accepted as documented (they used to be silently dropped).
        for flag, names in (
            ("-kernels-used", self.kernels_used),
            ("-variables-used", self.variables_used),
        ):
            if isinstance(names, str):
                options.append(f"{flag}={names}")
            elif isinstance(names, (list, tuple)):
                options.extend(f"{flag}={name}" for name in names)
        if self.optimize_unused_variables:
            options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            if isinstance(self.ptxas_options, str):
                options.append(f"-Xptxas={self.ptxas_options}")
            elif is_sequence(self.ptxas_options):
                for opt in self.ptxas_options:
                    options.append(f"-Xptxas={opt}")
        if self.split_compile is not None:
            options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            options.append("-no-cache")

        if as_bytes:
            return [o.encode() for o in options]
        else:
            return options

    def _prepare_driver_options(self) -> tuple[list, list]:
        """Translate the configured options into cuLink* option values and keys.

        Returns
        -------
        tuple[list, list]
            ``(formatted_options, option_keys)``: two lists of equal length
            where ``option_keys[i]`` is the ``CUjit_option`` for the value
            ``formatted_options[i]``. The first four entries are always the
            info-log buffer, its size, the error-log buffer and its size.

        Raises
        ------
        ValueError
            If an option that the driver backend cannot honor is requested
            (``time``, ``ptx``, ``ptxas_options``, ``split_compile``,
            ``split_compile_extended``).
        """
        formatted_options = []
        option_keys = []

        # allocate a fixed-sized buffer (4 MiB) for each info/error log
        size = 4194304
        formatted_options.extend((bytearray(size), size, bytearray(size), size))
        option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_80" / "compute_80" -> "80" -> CU_TARGET_COMPUTE_80
            arch = self.arch.split("_")[-1].upper()
            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            formatted_options.append(self.max_register_count)
            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        # Only an actual request for timing is an error; ``time=False`` asks
        # for nothing the driver lacks (same truthiness rule as ``ptx`` below).
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            formatted_options.append(self.optimization_level)
            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        # The options below are not forwarded to the driver; the caller only
        # gets a DeprecationWarning.
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

        return formatted_options, option_keys

    def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
        """Convert linker options to bytes format for the nvjitlink backend.

        Parameters
        ----------
        backend : str, optional
            The linker backend. Only "nvjitlink" is supported (matched
            case-insensitively). Default is "nvjitlink".

        Returns
        -------
        list[bytes]
            List of option strings encoded as bytes.

        Raises
        ------
        ValueError
            If an unsupported backend is specified.
        RuntimeError
            If nvJitLink backend is not available.
        """
        backend = backend.lower()
        if backend != "nvjitlink":
            raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
        if not _use_nvjitlink_backend:
            raise RuntimeError("nvJitLink backend is not available")
        return self._prepare_nvjitlink_options(as_bytes=True)

436  

437  

438# ============================================================================= 

439# Private implementation: cdef inline helpers 

440# ============================================================================= 

441  

cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1:
    """Initialize a Linker instance.

    Creates the backend link handle (nvJitLink or cuLink, depending on the
    module-level ``_use_nvjitlink_backend`` flag) from ``options`` and then
    adds every entry of ``object_codes`` to it.

    Raises ``ValueError`` when ``object_codes`` is empty. Returns 0 on
    success; -1 is the C-level error sentinel (``except -1``).
    """
    if len(object_codes) == 0:
        raise ValueError("At least one ObjectCode object must be provided")

    cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink
    cdef cydriver.CUlinkState c_raw_culink
    cdef Py_ssize_t c_num_opts, i
    cdef vector[const_char_ptr] c_str_opts
    cdef vector[cydriver.CUjit_option] c_jit_keys
    cdef vector[void_ptr] c_jit_values

    # Normalize ``options`` through check_or_create_options and keep the
    # result on the instance; link() later reads its ``name``.
    self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")

    if _use_nvjitlink_backend:
        self._use_nvjitlink = True
        # ``options_bytes`` must stay referenced until nvJitLinkCreate returns:
        # c_str_opts only stores raw pointers taken from these bytes objects.
        options_bytes = options._prepare_nvjitlink_options(as_bytes=True)
        c_num_opts = len(options_bytes)
        c_str_opts.resize(c_num_opts)
        for i in range(c_num_opts):
            c_str_opts[i] = <const char*>(<bytes>options_bytes[i])
        with nogil:
            # No handle exists yet, hence NULL as the first argument.
            HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate(
                &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data()))
        self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink)
    else:
        self._use_nvjitlink = False
        formatted_options, option_keys = options._prepare_driver_options()
        # Keep the formatted_options list alive: it contains bytearrays that
        # the driver writes into via raw pointers during linking operations.
        self._drv_log_bufs = formatted_options
        c_num_opts = len(option_keys)
        c_jit_keys.resize(c_num_opts)
        c_jit_values.resize(c_num_opts)
        for i in range(c_num_opts):
            c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i]
            val = formatted_options[i]
            if isinstance(val, bytearray):
                # Log buffers are passed as a pointer to the bytearray's storage.
                c_jit_values[i] = <void*>PyByteArray_AS_STRING(val)
            else:
                # Every other value is passed as an integer stored in the pointer slot.
                c_jit_values[i] = <void*><intptr_t>int(val)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkCreate(
                    <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink))
        except CUDAError as e:
            # Attach the error-log text to the exception before re-raising.
            Linker_annotate_error_log(self, e)
            raise
        self._culink_handle = create_culink_handle(c_raw_culink)

    # Inputs are type-checked and added one at a time, in the order given.
    for code in object_codes:
        assert_type(code, ObjectCode)
        Linker_add_code_object(self, code)
    return 0

496  

497  

cdef inline void Linker_add_code_object(Linker self, object object_code) except *:
    """Add a single ObjectCode to the linker.

    ``object_code.code`` may be ``bytes`` (added as in-memory data) or ``str``
    (passed to the backend's add-file call). Any other type raises
    ``TypeError``. A ``code_type`` that has no entry in the active backend's
    input-type table raises ``ValueError``.
    """
    data = object_code.code
    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef cynvjitlink.nvJitLinkInputType c_nv_input_type
    cdef cydriver.CUjitInputType c_drv_input_type
    cdef const char* c_data_ptr
    cdef size_t c_data_size
    cdef const char* c_name_ptr
    cdef const char* c_file_ptr

    # ``name_bytes`` is kept in a local so the object behind c_name_ptr
    # stays referenced for the duration of the backend call.
    name_bytes = f"{object_code.name}".encode()
    c_name_ptr = <const char*>name_bytes

    # Map the ObjectCode's code_type string to the backend's input-type enum value.
    input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types
    py_input_type = input_types.get(object_code.code_type)
    if py_input_type is None:
        raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}")

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type
        if isinstance(data, bytes):
            c_data_ptr = <const char*>(<bytes>data)
            c_data_size = len(data)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData(
                    c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr))
        elif isinstance(data, str):
            # A str is handed to nvJitLinkAddFile (presumably a file path — confirm with callers).
            file_bytes = data.encode()
            c_file_ptr = <const char*>file_bytes
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile(
                    c_nvjitlink_h, c_nv_input_type, c_file_ptr))
        else:
            raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
    else:
        c_culink_state = as_cu(self._culink_handle)
        c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type
        try:
            if isinstance(data, bytes):
                c_data_ptr = <const char*>(<bytes>data)
                c_data_size = len(data)
                with nogil:
                    # Trailing ``0, NULL, NULL``: no per-input JIT options are passed.
                    HANDLE_RETURN(cydriver.cuLinkAddData(
                        c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr,
                        0, NULL, NULL))
            elif isinstance(data, str):
                file_bytes = data.encode()
                c_file_ptr = <const char*>file_bytes
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddFile(
                        c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL))
            else:
                raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
        except CUDAError as e:
            # Only CUDAError is annotated with the error log; the TypeError above propagates as-is.
            Linker_annotate_error_log(self, e)
            raise

557  

558  

cdef inline object Linker_link(Linker self, str target_type):
    """Complete linking and return the result as ObjectCode.

    ``target_type`` must be "cubin" or "ptx"; anything else raises
    ``ValueError`` before the backend is touched. On success the info and
    error logs are cached on the instance and the driver log buffers are
    released.
    """
    if target_type not in ("cubin", "ptx"):
        raise ValueError(f"Unsupported target type: {target_type}")

    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef size_t c_output_size = 0
    cdef char* c_code_ptr
    cdef void* c_cubin_out = NULL

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h))
        # Two-step retrieval: query the output size, allocate a bytearray of
        # that size, then have nvJitLink fill it in place.
        if target_type == "cubin":
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr))
        else:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr))
    else:
        c_culink_state = as_cu(self._culink_handle)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size))
        except CUDAError as e:
            Linker_annotate_error_log(self, e)
            raise
        # Slicing the char* copies c_output_size bytes into a Python bytes object.
        # NOTE(review): the output pointer is presumably owned by the link
        # state — confirm against the driver API documentation.
        code = (<char*>c_cubin_out)[:c_output_size]

    # Linking is complete; cache the decoded log strings and release
    # the driver's raw bytearray buffers (no longer written to).
    # The getters must run before _drv_log_bufs is cleared: on the driver
    # backend they read from it while the cached values are still None.
    self._info_log = self.get_info_log()
    self._error_log = self.get_error_log()
    self._drv_log_bufs = None

    return ObjectCode._init(bytes(code), target_type, name=self._options.name)

607  

608  

cdef inline void Linker_annotate_error_log(Linker self, object e):
    """Append the linker's error log, when it is non-empty, to the first argument of ``e``."""
    log_text = self.get_error_log()
    if not log_text:
        return
    annotated = e.args[0] + f"\nLinker error log: {log_text}"
    e.args = (annotated, *e.args[1:])

614  

615  

616# ============================================================================= 

617# Private implementation: module-level state and initialization 

618# ============================================================================= 

619  

# TODO: revisit this treatment for py313t builds
# Set to the ``driver`` module by _decide_nvjitlink_or_driver() when it falls back.
_driver = None  # populated if nvJitLink cannot be used
# Flipped to True at the end of the first _lazy_init() run.
_inited = False
# Tri-state: None until decided, then True (nvJitLink) or False (driver).
_use_nvjitlink_backend = None  # set by _decide_nvjitlink_or_driver()

# Input type mappings populated by _lazy_init() with C-level enum ints.
# Only the table for the selected backend is filled in; the other stays None.
_nvjitlink_input_types = None
_driver_input_types = None

628  

629  

def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
    """Report whether ``nvjitlink`` resolves the ``__nvJitLinkVersion`` symbol."""
    # A resolvable symbol is equivalent to the library being version >= 12.3.
    symbol_address = nvjitlink._inspect_function_pointer("__nvJitLinkVersion")
    return True if symbol_address else False

633  

634  

635# Note: this function is reused in the tests 

def _decide_nvjitlink_or_driver() -> bool:
    """Choose the linking backend once; return True when the cuLink* driver APIs are the fallback.

    The outcome is remembered in ``_use_nvjitlink_backend``, so later calls
    return immediately without probing or warning again.
    """
    global _driver, _use_nvjitlink_backend
    if _use_nvjitlink_backend is not None:
        return not _use_nvjitlink_backend

    # Shared middle part of both fallback warnings.
    fallback_note = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    bindings_module = _optional_cuda_import(
        "cuda.bindings.nvjitlink",
        probe_function=lambda module: module.version(),  # probe triggers nvJitLink runtime load
    )
    if bindings_module is not None:
        from cuda.bindings._internal import nvjitlink

        if _nvjitlink_has_version_symbol(nvjitlink):
            _use_nvjitlink_backend = True
            return False  # Use nvjitlink
        library_name = "nvJitLink*.dll" if sys.platform == "win32" else "libnvJitLink.so*"
        message = (
            f"{library_name} is too old (<12.3)."
            f" Therefore cuda.bindings.nvjitlink is not usable and {fallback_note} nvJitLink."
        )
    else:
        message = f"cuda.bindings.nvjitlink is not available, therefore {fallback_note} cuda-bindings."

    warn(message, stacklevel=2, category=RuntimeWarning)
    _use_nvjitlink_backend = False
    _driver = driver
    return True

669  

670  

def _lazy_init():
    """Pick the backend and build its code-type -> input-type table, once per process."""
    global _inited, _nvjitlink_input_types, _driver_input_types
    if not _inited:
        _decide_nvjitlink_or_driver()
        if not _use_nvjitlink_backend:
            # Driver fallback: there is no "ltoir" entry in this table.
            _driver_input_types = {
                "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
                "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
                "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
                "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
                "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
            }
        else:
            _nvjitlink_input_types = {
                "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
                "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
                "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
                "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
                "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
                "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
            }
        _inited = True