Coverage for cuda/core/_linker.pyx: 81.10%
365 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4"""Linking machinery for combining object codes.
6This module provides :class:`Linker` for linking one or more
7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for
8configuration.
9"""
11from __future__ import annotations
13from cpython.bytearray cimport PyByteArray_AS_STRING
14from libc.stdint cimport intptr_t, uint32_t
15from libcpp.vector cimport vector
16from cuda.bindings cimport cydriver
17from cuda.bindings cimport cynvjitlink
19from ._resource_handles cimport (
20 as_cu,
21 as_py,
22 create_culink_handle,
23 create_nvjitlink_handle,
24)
25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK
27import sys
28from dataclasses import dataclass
29from typing import TYPE_CHECKING, Union
30from warnings import warn
32from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
33from cuda.core._device import Device
34from cuda.core._module import ObjectCode
35from cuda.core._utils.clear_error_support import assert_type
36from cuda.core._utils.cuda_utils import (
37 CUDAError,
38 check_or_create_options,
39 driver,
40 is_sequence,
41)
42from cuda.core.typing import CompilerBackendType, ObjectCodeFormatType
44if TYPE_CHECKING:
45 import cuda.bindings.driver # no-cython-lint
46 import cuda.bindings.nvjitlink # no-cython-lint
# Module-level annotations to ensure stubgen-pyx keeps the above imports in
# the generated `.pyi` so that the LinkerHandleT forward references resolve.
# These names are not assigned, so they only affect __annotations__.
_keep_driver_in_stub: "cuda.bindings.driver.CUlinkState"
_keep_nvjitlink_in_stub: "cuda.bindings.nvjitlink.nvJitLinkHandle"

# C-level pointer aliases, used as the element types of the libcpp.vector
# locals in Linker_init (vector[const_char_ptr], vector[void_ptr]).
ctypedef const char* const_char_ptr
ctypedef void* void_ptr

__all__ = ["Linker", "LinkerOptions"]

# Type of the object returned by Linker.handle; which union member is
# returned depends on the backend the Linker was created with.
LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"]
62# =============================================================================
63# Principal class
64# =============================================================================
cdef class Linker:
    """Represent a linking machinery to link one or more object codes into
    :class:`~cuda.core.ObjectCode`.

    This object provides a unified interface to multiple underlying
    linker libraries (such as nvJitLink or cuLink* from the CUDA driver).

    Parameters
    ----------
    object_codes : :class:`~cuda.core.ObjectCode`
        One or more ObjectCode objects to be linked.
    options : :class:`LinkerOptions`, optional
        Options for the linker. If not provided, default options will be used.
    """

    def __init__(self, *object_codes: ObjectCode, options: LinkerOptions | None = None):
        Linker_init(self, object_codes, options)

    def link(self, target_type: ObjectCodeFormatType | str) -> ObjectCode:
        """Link the provided object codes into a single output of the specified target type.

        Parameters
        ----------
        target_type : ObjectCodeFormatType | str
            The type of the target output. Must be either "cubin" or "ptx".

        Returns
        -------
        :class:`~cuda.core.ObjectCode`
            The linked object code of the specified target type.

        .. note::

           Ensure that input object codes were compiled with appropriate
           flags for linking (e.g., relocatable device code enabled).
        """
        return Linker_link(self, str(target_type))

    def get_error_log(self) -> str:
        """Get the error log generated by the linker.

        Returns
        -------
        str
            The error log, with any trailing NUL bytes removed so that both
            backends return the same text.
        """
        cdef cynvjitlink.nvJitLinkHandle c_h
        cdef size_t c_log_size = 0
        cdef char* c_log_ptr
        # After link(), the decoded log is cached here.
        if self._error_log is not None:
            return self._error_log
        if not self._use_nvjitlink:
            # Driver backend: the log lives in the zero-filled bytearray that
            # was handed to cuLinkCreate (index 2 of the driver options).
            return (<bytearray>self._drv_log_bufs[2]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')
        c_h = as_cu(self._nvjitlink_handle)
        HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetErrorLogSize(c_h, &c_log_size))
        log = bytearray(c_log_size)
        if c_log_size > 0:
            c_log_ptr = <char*>(<bytearray>log)
            HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetErrorLog(c_h, c_log_ptr))
        # Drop a NUL terminator if the reported size includes one; this
        # matches the driver branch above.
        return log.decode("utf-8", errors="backslashreplace").rstrip('\x00')

    def get_info_log(self) -> str:
        """Get the info log generated by the linker.

        Returns
        -------
        str
            The info log, with any trailing NUL bytes removed so that both
            backends return the same text.
        """
        cdef cynvjitlink.nvJitLinkHandle c_h
        cdef size_t c_log_size = 0
        cdef char* c_log_ptr
        # After link(), the decoded log is cached here.
        if self._info_log is not None:
            return self._info_log
        if not self._use_nvjitlink:
            # Driver backend: the log lives in the zero-filled bytearray that
            # was handed to cuLinkCreate (index 0 of the driver options).
            return (<bytearray>self._drv_log_bufs[0]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')
        c_h = as_cu(self._nvjitlink_handle)
        HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetInfoLogSize(c_h, &c_log_size))
        log = bytearray(c_log_size)
        if c_log_size > 0:
            c_log_ptr = <char*>(<bytearray>log)
            HANDLE_RETURN_NVJITLINK(c_h, cynvjitlink.nvJitLinkGetInfoLog(c_h, c_log_ptr))
        # Drop a NUL terminator if the reported size includes one; this
        # matches the driver branch above.
        return log.decode("utf-8", errors="backslashreplace").rstrip('\x00')

    def close(self) -> None:
        """Destroy this linker.

        Resets the backend handle (nvJitLink or cuLink, depending on which
        backend this linker was created with) held by this object.
        """
        if self._use_nvjitlink:
            self._nvjitlink_handle.reset()
        else:
            self._culink_handle.reset()

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

           The type of the returned object depends on the backend.

        .. caution::

           This handle is a Python object. To get the memory address of the underlying C
           handle, call ``int(Linker.handle)``.
        """
        if self._use_nvjitlink:
            return as_py(self._nvjitlink_handle)
        else:
            return as_py(self._culink_handle)

    @classmethod
    def which_backend(cls) -> CompilerBackendType:
        """Return which linking backend will be used.

        Returns :attr:`~CompilerBackendType.NVJITLINK` when the nvJitLink
        library is available and meets the minimum version requirement,
        otherwise :attr:`~CompilerBackendType.DRIVER`.

        .. note::

           Prefer letting :class:`Linker` decide. Query ``which_backend()``
           only when you need to dispatch based on input format (for
           example: choose PTX vs. LTOIR before constructing a
           ``Linker``). The returned value names an implementation
           detail whose support matrix may shift across CTK releases.
        """
        # _decide_nvjitlink_or_driver() returns True when falling back to cuLink.
        return CompilerBackendType.DRIVER if _decide_nvjitlink_or_driver() else CompilerBackendType.NVJITLINK
200# =============================================================================
201# Supporting classes
202# =============================================================================
@dataclass
class LinkerOptions:
    """Customizable options for configuring :class:`Linker`.

    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated :class:`ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log. Only a truthy value enables it.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [str | tuple[str] | list[str]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [str | tuple[str] | list[str]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Only a truthy value enables it.
        Default: False.
    ptxas_options : [str | tuple[str] | list[str]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: str | tuple[str] | list[str] | None = None
    variables_used: str | tuple[str] | list[str] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: str | tuple[str] | list[str] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None
    numba_debug: bool | None = None

    def __post_init__(self) -> None:
        _lazy_init()
        self._name = self.name.encode()

    @staticmethod
    def _format_repeated(flag: str, value) -> list[str]:
        """Expand a str-or-sequence option into ``flag=item`` entries.

        A single ``str`` yields one entry, and any other sequence (list, tuple,
        ...) yields one entry per element. ``None`` and unsupported types yield
        nothing.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [f"{flag}={value}"]
        if is_sequence(value):
            return [f"{flag}={item}" for item in value]
        return []

    def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]:
        """Translate these options into nvJitLink command-line flags.

        Parameters
        ----------
        as_bytes : bool, optional
            If True, return each flag encoded as ``bytes`` (the form handed to
            ``nvJitLinkCreate``); otherwise return ``str``. Default: False.

        Returns
        -------
        list[bytes] | list[str]
            The flags, in a fixed order starting with ``-arch=...``.
        """
        options = []

        if self.arch is not None:
            options.append(f"-arch={self.arch}")
        else:
            # Default to the current device: the compute-capability
            # components are concatenated after "sm_".
            options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            options.append(f"-maxrregcount={self.max_register_count}")
        # Plain on/off switches are emitted only when truthy, so an explicit
        # False leaves the feature disabled.
        if self.time:
            options.append("-time")
        if self.verbose:
            options.append("-verbose")
        if self.link_time_optimization:
            options.append("-lto")
        if self.ptx:
            options.append("-ptx")
        if self.optimization_level is not None:
            options.append(f"-O{self.optimization_level}")
        if self.debug:
            options.append("-g")
        if self.lineinfo:
            options.append("-lineinfo")
        # These take an explicit true/false value, so any non-None setting is forwarded.
        if self.ftz is not None:
            options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            options.append(f"-fma={'true' if self.fma else 'false'}")
        options.extend(self._format_repeated("-kernels-used", self.kernels_used))
        options.extend(self._format_repeated("-variables-used", self.variables_used))
        if self.optimize_unused_variables:
            options.append("-optimize-unused-variables")
        options.extend(self._format_repeated("-Xptxas", self.ptxas_options))
        if self.split_compile is not None:
            options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            options.append("-no-cache")

        if as_bytes:
            return [o.encode() for o in options]
        else:
            return options

    def _prepare_driver_options(self) -> tuple[list[object], list[object]]:
        """Translate these options into parallel value/key lists for ``cuLinkCreate``.

        The first four entries are always the info-log buffer, its size, the
        error-log buffer and its size. :class:`Linker` reads the logs back
        from indices 0 and 2.

        Returns
        -------
        tuple[list[object], list[object]]
            ``(formatted_options, option_keys)``, with matching indices.

        Raises
        ------
        ValueError
            If an option the driver API cannot honor is set.
        """
        formatted_options = []
        option_keys = []

        # allocate a fixed-sized buffer for each info/error log
        size = 4194304
        formatted_options.extend((bytearray(size), size, bytearray(size), size))
        option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_86" / "compute_86" -> CU_TARGET_COMPUTE_86
            arch = self.arch.split("_")[-1].upper()
            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            formatted_options.append(self.max_register_count)
            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            formatted_options.append(self.optimization_level)
            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

        return formatted_options, option_keys

    def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
        """Convert linker options to bytes format for the nvjitlink backend.

        Parameters
        ----------
        backend : str, optional
            The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink".

        Returns
        -------
        list[bytes]
            List of option strings encoded as bytes.

        Raises
        ------
        ValueError
            If an unsupported backend is specified.
        RuntimeError
            If nvJitLink backend is not available.
        """
        backend = backend.lower()
        if backend != "nvjitlink":
            raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
        if not _use_nvjitlink_backend:
            raise RuntimeError("nvJitLink backend is not available")
        return self._prepare_nvjitlink_options(as_bytes=True)
463# =============================================================================
464# Private implementation: cdef inline helpers
465# =============================================================================
467cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1:
468 """Initialize a Linker instance."""
469 if len(object_codes) == 0: 1$OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
470 raise ValueError("At least one ObjectCode object must be provided") 1$
472 cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink
473 cdef cydriver.CUlinkState c_raw_culink
474 cdef Py_ssize_t c_num_opts, i
475 cdef vector[const_char_ptr] c_str_opts
476 cdef vector[cydriver.CUjit_option] c_jit_keys
477 cdef vector[void_ptr] c_jit_values
479 self._options = options = check_or_create_options(LinkerOptions, options, "Linker options") 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
481 if _use_nvjitlink_backend: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
482 self._use_nvjitlink = True 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
483 options_bytes = options._prepare_nvjitlink_options(as_bytes=True) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
484 c_num_opts = len(options_bytes) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
485 c_str_opts.resize(c_num_opts) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
486 for i in range(c_num_opts): 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
487 c_str_opts[i] = <const char*>(<bytes>options_bytes[i]) 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
488 with nogil: 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
489 HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate( 1OtRMpzAqBrCumkDEFGHIvnwxoyJTKQLlsPcdefghiabj
490 &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data()))
491 self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
492 else:
493 self._use_nvjitlink = False
494 formatted_options, option_keys = options._prepare_driver_options()
495 # Keep the formatted_options list alive: it contains bytearrays that
496 # the driver writes into via raw pointers during linking operations.
497 self._drv_log_bufs = formatted_options
498 c_num_opts = len(option_keys)
499 c_jit_keys.resize(c_num_opts)
500 c_jit_values.resize(c_num_opts)
501 for i in range(c_num_opts):
502 c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i]
503 val = formatted_options[i]
504 if isinstance(val, bytearray):
505 c_jit_values[i] = <void*>PyByteArray_AS_STRING(val)
506 else:
507 c_jit_values[i] = <void*><intptr_t>int(val)
508 try:
509 with nogil:
510 HANDLE_RETURN(cydriver.cuLinkCreate(
511 <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink))
512 except CUDAError as e:
513 Linker_annotate_error_log(self, e)
514 raise
515 self._culink_handle = create_culink_handle(c_raw_culink)
517 for code in object_codes: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
518 assert_type(code, ObjectCode) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
519 Linker_add_code_object(self, code) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
520 return 0 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
523cdef inline void Linker_add_code_object(Linker self, object object_code) except *:
524 """Add a single ObjectCode to the linker."""
525 data = object_code.code 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
526 cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
527 cdef cydriver.CUlinkState c_culink_state
528 cdef cynvjitlink.nvJitLinkInputType c_nv_input_type
529 cdef cydriver.CUjitInputType c_drv_input_type
530 cdef const char* c_data_ptr
531 cdef size_t c_data_size
532 cdef const char* c_name_ptr
533 cdef const char* c_file_ptr
535 name_bytes = f"{object_code.name}".encode() 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
536 c_name_ptr = <const char*>name_bytes 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
538 input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
539 py_input_type = input_types.get(object_code.code_type) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
540 if py_input_type is None: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
541 raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}")
543 if self._use_nvjitlink: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
544 c_nvjitlink_h = as_cu(self._nvjitlink_handle) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
545 c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
546 if isinstance(data, bytes): 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
547 c_data_ptr = <const char*>(<bytes>data) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
548 c_data_size = len(data) 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
549 with nogil: 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
550 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData( 1OtRMpzAqBrCumkDEFGHIvnwxoyJKQLlsPcdefghiabj
551 c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr))
552 elif isinstance(data, str):
553 file_bytes = data.encode()
554 c_file_ptr = <const char*>file_bytes
555 with nogil:
556 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile(
557 c_nvjitlink_h, c_nv_input_type, c_file_ptr))
558 else:
559 raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
560 else:
561 c_culink_state = as_cu(self._culink_handle)
562 c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type
563 try:
564 if isinstance(data, bytes):
565 c_data_ptr = <const char*>(<bytes>data)
566 c_data_size = len(data)
567 with nogil:
568 HANDLE_RETURN(cydriver.cuLinkAddData(
569 c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr,
570 0, NULL, NULL))
571 elif isinstance(data, str):
572 file_bytes = data.encode()
573 c_file_ptr = <const char*>file_bytes
574 with nogil:
575 HANDLE_RETURN(cydriver.cuLinkAddFile(
576 c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL))
577 else:
578 raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
579 except CUDAError as e:
580 Linker_annotate_error_log(self, e)
581 raise
584cdef inline object Linker_link(Linker self, str target_type):
585 """Complete linking and return the result as ObjectCode."""
586 if target_type not in ("cubin", "ptx"): 1OtMpzAqBrCumkDEFGHIvnwxoyJKQLlscdefghiabj
587 raise ValueError(f"Unsupported target type: {target_type}") 1Q
589 cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
590 cdef cydriver.CUlinkState c_culink_state
591 cdef size_t c_output_size = 0 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
592 cdef char* c_code_ptr
593 cdef void* c_cubin_out = NULL 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
595 if self._use_nvjitlink: 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
596 c_nvjitlink_h = as_cu(self._nvjitlink_handle) 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
597 with nogil: 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
598 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h)) 1OtMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
599 if target_type == "cubin": 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
600 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj
601 cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size)) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj
602 code = bytearray(c_output_size) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj
603 c_code_ptr = <char*>(<bytearray>code) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj
604 with nogil: 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj
605 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj
606 cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr)) 1tMpzAqBrCumkDEFGHIvnwxoyJKLscdefghiabj
607 else:
608 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1l
609 cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size)) 1l
610 code = bytearray(c_output_size) 1l
611 c_code_ptr = <char*>(<bytearray>code) 1l
612 with nogil: 1l
613 HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, 1l
614 cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr)) 1l
615 else:
616 c_culink_state = as_cu(self._culink_handle)
617 try:
618 with nogil:
619 HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size))
620 except CUDAError as e:
621 Linker_annotate_error_log(self, e)
622 raise
623 code = (<char*>c_cubin_out)[:c_output_size]
625 # Linking is complete; cache the decoded log strings and release
626 # the driver's raw bytearray buffers (no longer written to).
627 self._info_log = self.get_info_log() 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
628 self._error_log = self.get_error_log() 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
629 self._drv_log_bufs = None 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
631 return ObjectCode._init(bytes(code), target_type, name=self._options.name) 1tMpzAqBrCumkDEFGHIvnwxoyJKLlscdefghiabj
634cdef inline void Linker_annotate_error_log(Linker self, object e):
635 """Annotate a CUDAError with the driver linker error log."""
636 error_log = self.get_error_log()
637 if error_log:
638 e.args = (e.args[0] + f"\nLinker error log: {error_log}", *e.args[1:])
641# =============================================================================
642# Private implementation: module-level state and initialization
643# =============================================================================
645# TODO: revisit this treatment for py313t builds
646_driver = None # populated if nvJitLink cannot be used
647_inited = False
648_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver()
650# Input type mappings populated by _lazy_init() with C-level enum ints.
651_nvjitlink_input_types = None
652_driver_input_types = None
655def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
656 # This condition is equivalent to testing for version >= 12.3
657 return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
660# Note: this function is reused in the tests
661def _decide_nvjitlink_or_driver() -> bool:
662 """Return True if falling back to the cuLink* driver APIs."""
663 global _driver, _use_nvjitlink_backend
664 if _use_nvjitlink_backend is not None: 2N % ' M p z A q B r C u m k D E F G H I v n w x o y J 8 # P c d e f g h i a b j ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbcb
665 return not _use_nvjitlink_backend 2N % ' M p z A q B r C u m k D E F G H I v n w x o y J P c d e f g h i a b j ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbcb
667 warn_txt_common = (
668 "the driver APIs will be used instead, which do not support" 1N8#
669 " minor version compatibility or linking LTO IRs."
670 " For best results, consider upgrading to a recent version of"
671 )
673 nvjitlink_module = _optional_cuda_import( 1N8#
674 "cuda.bindings.nvjitlink",
675 probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load 1N8#
676 )
677 if nvjitlink_module is None: 1N8
678 warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings." 18
679 else:
680 from cuda.bindings._internal import nvjitlink
682 if _nvjitlink_has_version_symbol(nvjitlink):
683 _use_nvjitlink_backend = True
684 return False # Use nvjitlink
685 warn_txt = (
686 f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
687 f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
688 )
690 warn(warn_txt, stacklevel=2, category=RuntimeWarning) 18
691 _use_nvjitlink_backend = False 18
692 _driver = driver 18
693 return True 18
696def _lazy_init() -> None:
697 global _inited, _nvjitlink_input_types, _driver_input_types
698 if _inited: 1N9OtRTKQLls!U5SVWXYZ0176432Pcdefghiabj
699 return 1N9OtRTKQLls!U5SVWXYZ0176432Pcdefghiabj
701 _decide_nvjitlink_or_driver()
702 if _use_nvjitlink_backend:
703 _nvjitlink_input_types = {
704 "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
705 "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
706 "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
707 "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
708 "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
709 "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
710 }
711 else:
712 _driver_input_types = {
713 "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
714 "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
715 "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
716 "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
717 "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
718 }
719 _inited = True