Coverage for cuda / core / _linker.pyx: 64.64%
362 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4"""Linking machinery for combining object codes.
6This module provides :class:`Linker` for linking one or more
7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for
8configuration.
9"""
11from __future__ import annotations
13from cpython.bytearray cimport PyByteArray_AS_STRING
14from libc.stdint cimport intptr_t, uint32_t
15from libcpp.vector cimport vector
16from cuda.bindings cimport cydriver
17from cuda.bindings cimport cynvjitlink
19from ._resource_handles cimport (
20 as_cu,
21 as_py,
22 create_culink_handle,
23 create_nvjitlink_handle,
24)
25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK
27import sys
28from dataclasses import dataclass
29from typing import Union
30from warnings import warn
32from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
33from cuda.core._device import Device
34from cuda.core._module import ObjectCode
35from cuda.core._utils.clear_error_support import assert_type
36from cuda.core._utils.cuda_utils import (
37 CUDAError,
38 check_or_create_options,
39 driver,
40 is_sequence,
41)
42from cuda.core.typing import CompilerBackendType
44ctypedef const char* const_char_ptr
45ctypedef void* void_ptr
47__all__ = ["Linker", "LinkerOptions"]
49LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"]
52# =============================================================================
53# Principal class
54# =============================================================================
cdef class Linker:
    """Link one or more object codes into a single :class:`~cuda.core.ObjectCode`.

    The class is a single front end over two possible backends: nvJitLink,
    or the cuLink* entry points of the CUDA driver.

    Parameters
    ----------
    object_codes : :class:`~cuda.core.ObjectCode`
        One or more ObjectCode objects to be linked.
    options : :class:`LinkerOptions`, optional
        Options for the linker. Default options are used when omitted.
    """

    def __init__(self, *object_codes: ObjectCode, options: "LinkerOptions" = None):
        # All construction work lives in the module-level cdef helper.
        Linker_init(self, object_codes, options)

    def link(self, target_type: ObjectCodeFormatType | str) -> ObjectCode:
        """Link every added object code into one output of the requested type.

        Parameters
        ----------
        target_type : ObjectCodeFormatType | str
            The kind of output to produce. Must be either "cubin" or "ptx".

        Returns
        -------
        :class:`~cuda.core.ObjectCode`
            The linked object code of the requested type.

        .. note::

            Ensure that input object codes were compiled with appropriate
            flags for linking (e.g., relocatable device code enabled).
        """
        # The helper takes a plain ``str``; enum-like inputs are stringified here.
        requested = str(target_type)
        return Linker_link(self, requested)

    def get_error_log(self) -> str:
        """Return the error log produced by the linker.

        Returns
        -------
        str
            The error log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_buf

        # A decoded copy is stored on the instance once linking has finished.
        if self._error_log is not None:
            return self._error_log

        if not self._use_nvjitlink:
            # Driver backend: slot 2 of the option values is the error-log buffer.
            return (<bytearray>self._drv_log_bufs[2]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')

        c_handle = as_cu(self._nvjitlink_handle)
        HANDLE_RETURN_NVJITLINK(c_handle, cynvjitlink.nvJitLinkGetErrorLogSize(c_handle, &c_nbytes))
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_buf = <char*>(<bytearray>buf)
            HANDLE_RETURN_NVJITLINK(c_handle, cynvjitlink.nvJitLinkGetErrorLog(c_handle, c_buf))
        return buf.decode("utf-8", errors="backslashreplace")

    def get_info_log(self) -> str:
        """Return the info log produced by the linker.

        Returns
        -------
        str
            The info log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_buf

        # A decoded copy is stored on the instance once linking has finished.
        if self._info_log is not None:
            return self._info_log

        if not self._use_nvjitlink:
            # Driver backend: slot 0 of the option values is the info-log buffer.
            return (<bytearray>self._drv_log_bufs[0]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')

        c_handle = as_cu(self._nvjitlink_handle)
        HANDLE_RETURN_NVJITLINK(c_handle, cynvjitlink.nvJitLinkGetInfoLogSize(c_handle, &c_nbytes))
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_buf = <char*>(<bytearray>buf)
            HANDLE_RETURN_NVJITLINK(c_handle, cynvjitlink.nvJitLinkGetInfoLog(c_handle, c_buf))
        return buf.decode("utf-8", errors="backslashreplace")

    def close(self):
        """Destroy this linker."""
        if not self._use_nvjitlink:
            self._culink_handle.reset()
            return
        self._nvjitlink_handle.reset()

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

            The type of the returned object depends on the backend.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Linker.handle)``.
        """
        if not self._use_nvjitlink:
            return as_py(self._culink_handle)
        return as_py(self._nvjitlink_handle)

    @classmethod
    def which_backend(cls) -> CompilerBackendType:
        """Return which linking backend will be used.

        Returns :attr:`~CompilerBackendType.NVJITLINK` when the nvJitLink
        library is available and meets the minimum version requirement,
        otherwise :attr:`~CompilerBackendType.DRIVER`.

        .. note::

            Prefer letting :class:`Linker` decide. Query ``which_backend()``
            only when you need to dispatch based on input format (for
            example: choose PTX vs. LTOIR before constructing a
            ``Linker``). The returned value names an implementation
            detail whose support matrix may shift across CTK releases.
        """
        # The helper answers "did we fall back to the driver?".
        fell_back_to_driver = _decide_nvjitlink_or_driver()
        if fell_back_to_driver:
            return CompilerBackendType.DRIVER
        return CompilerBackendType.NVJITLINK
190# =============================================================================
191# Supporting classes
192# =============================================================================
@dataclass
class LinkerOptions:
    """Customizable options for configuring :class:`Linker`.

    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated :class:`ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [str | tuple[str] | list[str]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [str | tuple[str] | list[str]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [str | tuple[str] | list[str]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: str | tuple[str] | list[str] | None = None
    variables_used: str | tuple[str] | list[str] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: str | tuple[str] | list[str] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        # Run the module's one-time initialization before the options are used.
        _lazy_init()
        self._name = self.name.encode()

    @staticmethod
    def _as_name_list(value) -> list:
        """Normalize a ``str`` / ``list`` / ``tuple`` option value into a list.

        A single string becomes a one-element list. Any other type (including
        ``None``) yields an empty list, i.e. contributes no options.
        """
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple)):
            return list(value)
        return []

    def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]:
        """Translate the configured fields into nvJitLink command-line style options.

        Parameters
        ----------
        as_bytes : bool, optional
            When True, every option string is encoded to ``bytes``. Default: False.

        Returns
        -------
        list[bytes] | list[str]
            The options, in a fixed order that follows the field order of this class.
        """
        options = []

        if self.arch is not None:
            options.append(f"-arch={self.arch}")
        else:
            # No explicit arch: derive sm_<CC> from the current device's compute capability.
            options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            options.append(f"-maxrregcount={self.max_register_count}")
        # Pure on/off flags are emitted only when truthy, so an explicit
        # ``False`` behaves the same as leaving the field unset.
        if self.time:
            options.append("-time")
        if self.verbose:
            options.append("-verbose")
        if self.link_time_optimization:
            options.append("-lto")
        if self.ptx:
            options.append("-ptx")
        if self.optimization_level is not None:
            options.append(f"-O{self.optimization_level}")
        if self.debug:
            options.append("-g")
        if self.lineinfo:
            options.append("-lineinfo")
        # Valued boolean options are forwarded for both True and False.
        if self.ftz is not None:
            options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            options.append(f"-fma={'true' if self.fma else 'false'}")
        # str, list and tuple are all accepted, matching the declared field types.
        options.extend(f"-kernels-used={kernel}" for kernel in self._as_name_list(self.kernels_used))
        options.extend(f"-variables-used={variable}" for variable in self._as_name_list(self.variables_used))
        if self.optimize_unused_variables:
            options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            if isinstance(self.ptxas_options, str):
                options.append(f"-Xptxas={self.ptxas_options}")
            elif is_sequence(self.ptxas_options):
                for opt in self.ptxas_options:
                    options.append(f"-Xptxas={opt}")
        if self.split_compile is not None:
            options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            options.append("-no-cache")

        if as_bytes:
            return [o.encode() for o in options]
        return options

    def _prepare_driver_options(self) -> tuple[list, list]:
        """Translate the configured fields into cuLink (driver API) JIT options.

        Returns
        -------
        tuple[list, list]
            ``(formatted_options, option_keys)``: two parallel lists holding the
            option values and the matching ``CUjit_option`` keys. The first four
            values are always the info-log buffer, its size, the error-log
            buffer and its size, in that order.

        Raises
        ------
        ValueError
            If ``time``, ``ptx``, ``ptxas_options``, ``split_compile`` or
            ``split_compile_extended`` is requested.

        Warns
        -----
        DeprecationWarning
            For fields that are set but are not forwarded to the driver
            (``ftz``, ``prec_div``, ``prec_sqrt``, ``fma``, ``kernels_used``,
            ``variables_used``, ``optimize_unused_variables``).
        """
        formatted_options = []
        option_keys = []

        # allocate a fixed-sized buffer for each info/error log
        size = 4194304
        formatted_options.extend((bytearray(size), size, bytearray(size), size))
        option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_90" / "compute_90" -> "90" -> CU_TARGET_COMPUTE_90
            arch = self.arch.split("_")[-1].upper()
            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            formatted_options.append(self.max_register_count)
            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        # Only an actual request for timing is an error; ``time=False`` asks for nothing.
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            formatted_options.append(self.optimization_level)
            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

        return formatted_options, option_keys

    def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
        """Convert linker options to bytes format for the nvjitlink backend.

        Parameters
        ----------
        backend : str, optional
            The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink".

        Returns
        -------
        list[bytes]
            List of option strings encoded as bytes.

        Raises
        ------
        ValueError
            If an unsupported backend is specified.
        RuntimeError
            If nvJitLink backend is not available.
        """
        backend = backend.lower()
        if backend != "nvjitlink":
            raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
        if not _use_nvjitlink_backend:
            raise RuntimeError("nvJitLink backend is not available")
        return self._prepare_nvjitlink_options(as_bytes=True)
452# =============================================================================
453# Private implementation: cdef inline helpers
454# =============================================================================
cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1:
    """Set up a Linker: create the backend handle, then add every object code.

    Raises ``ValueError`` when ``object_codes`` is empty. Returns 0 on success.
    """
    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_raw
    cdef cydriver.CUlinkState c_culink_raw
    cdef Py_ssize_t c_count, idx
    cdef vector[const_char_ptr] c_opt_strs
    cdef vector[cydriver.CUjit_option] c_opt_keys
    cdef vector[void_ptr] c_opt_vals

    if len(object_codes) == 0:
        raise ValueError("At least one ObjectCode object must be provided")

    options = check_or_create_options(LinkerOptions, options, "Linker options")
    self._options = options

    if not _use_nvjitlink_backend:
        # Driver (cuLink*) backend.
        self._use_nvjitlink = False
        drv_values, drv_keys = options._prepare_driver_options()
        # The value list must outlive this call: it holds the bytearrays the
        # driver writes its logs into through the raw pointers taken below.
        self._drv_log_bufs = drv_values
        c_count = len(drv_keys)
        c_opt_keys.resize(c_count)
        c_opt_vals.resize(c_count)
        for idx in range(c_count):
            c_opt_keys[idx] = <cydriver.CUjit_option><int>drv_keys[idx]
            entry = drv_values[idx]
            if isinstance(entry, bytearray):
                # Buffers are passed by address ...
                c_opt_vals[idx] = <void*>PyByteArray_AS_STRING(entry)
            else:
                # ... everything else is an integer smuggled through the pointer slot.
                c_opt_vals[idx] = <void*><intptr_t>int(entry)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkCreate(
                    <unsigned int>c_count, c_opt_keys.data(), c_opt_vals.data(), &c_culink_raw))
        except CUDAError as err:
            Linker_annotate_error_log(self, err)
            raise
        self._culink_handle = create_culink_handle(c_culink_raw)
    else:
        # nvJitLink backend.
        self._use_nvjitlink = True
        encoded_opts = options._prepare_nvjitlink_options(as_bytes=True)
        c_count = len(encoded_opts)
        c_opt_strs.resize(c_count)
        for idx in range(c_count):
            # Pointers borrow from ``encoded_opts``, which stays alive as a local.
            c_opt_strs[idx] = <const char*>(<bytes>encoded_opts[idx])
        with nogil:
            HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate(
                &c_nvjitlink_raw, <uint32_t>c_count, c_opt_strs.data()))
        self._nvjitlink_handle = create_nvjitlink_handle(c_nvjitlink_raw)

    for item in object_codes:
        assert_type(item, ObjectCode)
        Linker_add_code_object(self, item)
    return 0
cdef inline void Linker_add_code_object(Linker self, object object_code) except *:
    """Feed one ObjectCode into the active backend.

    ``object_code.code`` is added as in-memory data when it is ``bytes`` and
    as a file path when it is ``str``; anything else raises ``TypeError``.
    An unrecognized ``code_type`` raises ``ValueError``.
    """
    cdef cynvjitlink.nvJitLinkHandle c_nv_handle
    cdef cydriver.CUlinkState c_drv_state
    cdef cynvjitlink.nvJitLinkInputType c_nv_kind
    cdef cydriver.CUjitInputType c_drv_kind
    cdef const char* c_payload
    cdef size_t c_payload_len
    cdef const char* c_label
    cdef const char* c_path

    payload = object_code.code

    # Keep the encoded name referenced for as long as c_label is in use.
    label_bytes = f"{object_code.name}".encode()
    c_label = <const char*>label_bytes

    if self._use_nvjitlink:
        kind_table = _nvjitlink_input_types
    else:
        kind_table = _driver_input_types
    py_kind = kind_table.get(object_code.code_type)
    if py_kind is None:
        raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}")

    if not self._use_nvjitlink:
        c_drv_state = as_cu(self._culink_handle)
        c_drv_kind = <cydriver.CUjitInputType><int>py_kind
        try:
            if isinstance(payload, bytes):
                c_payload = <const char*>(<bytes>payload)
                c_payload_len = len(payload)
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddData(
                        c_drv_state, c_drv_kind, <void*>c_payload, c_payload_len, c_label,
                        0, NULL, NULL))
            elif isinstance(payload, str):
                path_bytes = payload.encode()
                c_path = <const char*>path_bytes
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddFile(
                        c_drv_state, c_drv_kind, c_path, 0, NULL, NULL))
            else:
                raise TypeError(f"Expected bytes or str, but got {type(payload).__name__}")
        except CUDAError as err:
            # Attach the driver's error log to the exception before it propagates.
            Linker_annotate_error_log(self, err)
            raise
        return

    c_nv_handle = as_cu(self._nvjitlink_handle)
    c_nv_kind = <cynvjitlink.nvJitLinkInputType><int>py_kind
    if isinstance(payload, bytes):
        c_payload = <const char*>(<bytes>payload)
        c_payload_len = len(payload)
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nv_handle, cynvjitlink.nvJitLinkAddData(
                c_nv_handle, c_nv_kind, <const void*>c_payload, c_payload_len, c_label))
    elif isinstance(payload, str):
        path_bytes = payload.encode()
        c_path = <const char*>path_bytes
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nv_handle, cynvjitlink.nvJitLinkAddFile(
                c_nv_handle, c_nv_kind, c_path))
    else:
        raise TypeError(f"Expected bytes or str, but got {type(payload).__name__}")
cdef inline object Linker_link(Linker self, str target_type):
    """Finish the link and wrap the produced image in an ObjectCode.

    ``target_type`` must be "cubin" or "ptx"; any other value raises
    ``ValueError``. After a successful link the decoded info/error logs are
    stored on the instance and the driver log buffers are dropped.
    """
    cdef cynvjitlink.nvJitLinkHandle c_nv_handle
    cdef cydriver.CUlinkState c_drv_state
    cdef size_t c_out_size = 0
    cdef char* c_out_ptr
    cdef void* c_drv_image = NULL

    if target_type not in ("cubin", "ptx"):
        raise ValueError(f"Unsupported target type: {target_type}")

    if not self._use_nvjitlink:
        c_drv_state = as_cu(self._culink_handle)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkComplete(c_drv_state, &c_drv_image, &c_out_size))
        except CUDAError as err:
            Linker_annotate_error_log(self, err)
            raise
        # Slicing the char pointer copies the image into a Python bytes object.
        linked = (<char*>c_drv_image)[:c_out_size]
    else:
        c_nv_handle = as_cu(self._nvjitlink_handle)
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nv_handle, cynvjitlink.nvJitLinkComplete(c_nv_handle))
        # Query the size, allocate, then fetch -- once per output flavour.
        if target_type == "ptx":
            HANDLE_RETURN_NVJITLINK(c_nv_handle,
                cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nv_handle, &c_out_size))
            linked = bytearray(c_out_size)
            c_out_ptr = <char*>(<bytearray>linked)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nv_handle,
                    cynvjitlink.nvJitLinkGetLinkedPtx(c_nv_handle, c_out_ptr))
        else:
            HANDLE_RETURN_NVJITLINK(c_nv_handle,
                cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nv_handle, &c_out_size))
            linked = bytearray(c_out_size)
            c_out_ptr = <char*>(<bytearray>linked)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nv_handle,
                    cynvjitlink.nvJitLinkGetLinkedCubin(c_nv_handle, c_out_ptr))

    # Snapshot both logs as decoded strings first; only then let go of the
    # driver's raw buffers, which the getters still read from.
    self._info_log = self.get_info_log()
    self._error_log = self.get_error_log()
    self._drv_log_bufs = None

    return ObjectCode._init(bytes(linked), target_type, name=self._options.name)
cdef inline void Linker_annotate_error_log(Linker self, object e):
    """Append the linker's error log to the first argument of exception ``e``.

    The exception is left untouched when the error log is empty.
    """
    log_text = self.get_error_log()
    if not log_text:
        return
    headline = e.args[0]
    remainder = e.args[1:]
    e.args = (headline + f"\nLinker error log: {log_text}", *remainder)
630# =============================================================================
631# Private implementation: module-level state and initialization
632# =============================================================================
# TODO: revisit this treatment for py313t builds
# Backend-selection state, all unset until the first initialization.
_inited = False
_use_nvjitlink_backend = None  # set by _decide_nvjitlink_or_driver()
_driver = None  # populated if nvJitLink cannot be used

# Input type mappings populated by _lazy_init() with C-level enum ints.
_nvjitlink_input_types = _driver_input_types = None
def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
    """Report whether ``nvjitlink`` resolves the ``__nvJitLinkVersion`` symbol."""
    # This condition is equivalent to testing for version >= 12.3
    pointer = nvjitlink._inspect_function_pointer("__nvJitLinkVersion")
    return True if pointer else False
649# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
    """Pick the linking backend once and return True when it is the driver.

    The outcome is stored in ``_use_nvjitlink_backend``; later calls simply
    report the stored decision. Falling back to the cuLink* driver APIs
    emits a ``RuntimeWarning`` and stores the driver module in ``_driver``.
    """
    global _driver, _use_nvjitlink_backend
    if _use_nvjitlink_backend is not None:
        # Already decided: translate "use nvJitLink" into "fell back to driver".
        return not _use_nvjitlink_backend

    warn_txt_common = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    def _probe(module):
        # probe triggers nvJitLink runtime load
        return module.version()

    probed = _optional_cuda_import("cuda.bindings.nvjitlink", probe_function=_probe)
    if probed is not None:
        from cuda.bindings._internal import nvjitlink

        if _nvjitlink_has_version_symbol(nvjitlink):
            _use_nvjitlink_backend = True
            return False  # Use nvjitlink
        message = (
            f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
            f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
        )
    else:
        message = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."

    # Warn before recording the decision, so a warning promoted to an error
    # leaves the module state undecided.
    warn(message, stacklevel=2, category=RuntimeWarning)
    _use_nvjitlink_backend = False
    _driver = driver
    return True
def _lazy_init():
    """Run one-time module setup: choose the backend and build its input-type table.

    Only the table for the selected backend is populated; each maps a
    code-type string to the integer value of the matching C enum.
    Subsequent calls return immediately.
    """
    global _inited, _nvjitlink_input_types, _driver_input_types
    if _inited:
        return

    _decide_nvjitlink_or_driver()
    if not _use_nvjitlink_backend:
        driver_table = {
            "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
            "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
            "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
            "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
            "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
        }
        _driver_input_types = driver_table
    else:
        nvjitlink_table = {
            "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
            "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
            "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
            "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
            "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
            "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
        }
        _nvjitlink_input_types = nvjitlink_table
    _inited = True