Coverage for cuda / core / _linker.pyx: 64.44%
360 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 01:27 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 01:27 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4"""Linking machinery for combining object codes.
6This module provides :class:`Linker` for linking one or more
7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for
8configuration.
9"""
11from __future__ import annotations
13from cpython.bytearray cimport PyByteArray_AS_STRING
14from libc.stdint cimport intptr_t, uint32_t
15from libcpp.vector cimport vector
16from cuda.bindings cimport cydriver
17from cuda.bindings cimport cynvjitlink
19from ._resource_handles cimport (
20 as_cu,
21 as_py,
22 create_culink_handle,
23 create_nvjitlink_handle,
24)
25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK
27import sys
28from dataclasses import dataclass
29from typing import Union
30from warnings import warn
32from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
33from cuda.core._device import Device
34from cuda.core._module import ObjectCode
35from cuda.core._utils.clear_error_support import assert_type
36from cuda.core._utils.cuda_utils import (
37 CUDAError,
38 check_or_create_options,
39 driver,
40 is_sequence,
41)
# Named aliases for pointer types; Linker_init uses them as element types of
# libcpp vectors (vector[const_char_ptr], vector[void_ptr]) when building the
# option arrays handed to nvJitLinkCreate / cuLinkCreate.
ctypedef const char* const_char_ptr
ctypedef void* void_ptr

__all__ = ["Linker", "LinkerOptions"]

# Return type of Linker.handle: an nvJitLink handle when the nvJitLink backend
# is in use, otherwise a driver CUlinkState.
LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"]
51# =============================================================================
52# Principal class
53# =============================================================================
cdef class Linker:
    """Link one or more object codes into a single :class:`~cuda.core.ObjectCode`.

    A single interface is exposed regardless of which linking library does
    the work underneath (nvJitLink, or the CUDA driver's cuLink* APIs).

    Parameters
    ----------
    object_codes : :class:`~cuda.core.ObjectCode`
        One or more ObjectCode objects to be linked.
    options : :class:`LinkerOptions`, optional
        Options for the linker. If not provided, default options will be used.
    """

    def __init__(self, *object_codes: ObjectCode, options: "LinkerOptions" = None):
        # Backend selection, handle creation and adding every input are all
        # delegated to the module-level Linker_init helper.
        Linker_init(self, object_codes, options)

    def link(self, target_type) -> ObjectCode:
        """Finish linking and return the result as the requested target type.

        Parameters
        ----------
        target_type : str
            Either ``"cubin"`` or ``"ptx"``.

        Returns
        -------
        :class:`~cuda.core.ObjectCode`
            The linked object code of the requested target type.

        .. note::

           Input object codes must have been compiled with flags suitable
           for linking (e.g., relocatable device code enabled).
        """
        result = Linker_link(self, target_type)
        return result

    def get_error_log(self) -> str:
        """Return the linker's error log as text.

        Returns
        -------
        str
            The error log; undecodable bytes are backslash-escaped.
        """
        cdef cynvjitlink.nvJitLinkHandle c_jit
        cdef size_t c_nbytes = 0
        # link() stores a decoded snapshot of the log; prefer it when present.
        if self._error_log is not None:
            return self._error_log
        if not self._use_nvjitlink:
            # Driver backend: slot 2 of the option values is the error-log buffer.
            drv_buf = <bytearray>(self._drv_log_bufs[2])
            return drv_buf.decode("utf-8", errors="backslashreplace").rstrip('\x00')
        c_jit = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetErrorLogSize(c_jit, &c_nbytes)
        text_buf = bytearray(c_nbytes)
        if c_nbytes:
            cynvjitlink.nvJitLinkGetErrorLog(c_jit, <char*>(<bytearray>text_buf))
        return text_buf.decode("utf-8", errors="backslashreplace")

    def get_info_log(self) -> str:
        """Return the linker's info log as text.

        Returns
        -------
        str
            The info log; undecodable bytes are backslash-escaped.
        """
        cdef cynvjitlink.nvJitLinkHandle c_jit
        cdef size_t c_nbytes = 0
        # link() stores a decoded snapshot of the log; prefer it when present.
        if self._info_log is not None:
            return self._info_log
        if not self._use_nvjitlink:
            # Driver backend: slot 0 of the option values is the info-log buffer.
            drv_buf = <bytearray>(self._drv_log_bufs[0])
            return drv_buf.decode("utf-8", errors="backslashreplace").rstrip('\x00')
        c_jit = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetInfoLogSize(c_jit, &c_nbytes)
        text_buf = bytearray(c_nbytes)
        if c_nbytes:
            cynvjitlink.nvJitLinkGetInfoLog(c_jit, <char*>(<bytearray>text_buf))
        return text_buf.decode("utf-8", errors="backslashreplace")

    def close(self):
        """Destroy this linker by releasing the active backend's handle."""
        if not self._use_nvjitlink:
            self._culink_handle.reset()
            return
        self._nvjitlink_handle.reset()

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

           The type of the returned object depends on the backend.

        .. caution::

           This handle is a Python object. To get the memory address of the underlying C
           handle, call ``int(Linker.handle)``.
        """
        if not self._use_nvjitlink:
            return as_py(self._culink_handle)
        return as_py(self._nvjitlink_handle)

    @property
    def backend(self) -> str:
        """Name of the backend in use: ``"nvJitLink"`` or ``"driver"``."""
        if self._use_nvjitlink:
            return "nvJitLink"
        return "driver"
176# =============================================================================
177# Supporting classes
178# =============================================================================
@dataclass
class LinkerOptions:
    """Customizable options for configuring :class:`Linker`.

    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    On/off switches (``time``, ``verbose``, ``link_time_optimization``, ``ptx``, ``debug``,
    ``lineinfo``, ``optimize_unused_variables``) take effect only when set to a truthy value;
    ``False`` and ``None`` both leave the feature disabled. ``ftz``, ``prec_div``, ``prec_sqrt``
    and ``fma`` are tri-state: ``None`` leaves the backend default, while ``True``/``False``
    are forwarded explicitly to nvJitLink.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated `ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [str | tuple[str] | list[str]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [str | tuple[str] | list[str]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [str | tuple[str] | list[str]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: str | tuple[str] | list[str] | None = None
    variables_used: str | tuple[str] | list[str] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: str | tuple[str] | list[str] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        # Constructing options is the first point where the backend must be
        # known, so the module-level backend decision is triggered here.
        _lazy_init()
        self._name = self.name.encode()

    def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]:
        """Translate these options into nvJitLink command-line flags.

        Parameters
        ----------
        as_bytes : bool
            When True, every flag is returned ``str.encode()``-d.

        Returns
        -------
        list[str] | list[bytes]
            The flags, in a fixed order starting with ``-arch=...``.
        """
        options = []

        if self.arch is not None:
            options.append(f"-arch={self.arch}")
        else:
            # Fall back to the current device, e.g. (9, 0) -> "-arch=sm_90".
            options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            options.append(f"-maxrregcount={self.max_register_count}")
        # Truthiness, not "is not None": time=False must not enable timing.
        if self.time:
            options.append("-time")
        if self.verbose:
            options.append("-verbose")
        if self.link_time_optimization:
            options.append("-lto")
        if self.ptx:
            options.append("-ptx")
        if self.optimization_level is not None:
            options.append(f"-O{self.optimization_level}")
        if self.debug:
            options.append("-g")
        if self.lineinfo:
            options.append("-lineinfo")
        if self.ftz is not None:
            options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            options.append(f"-fma={'true' if self.fma else 'false'}")
        if self.kernels_used is not None:
            # A single name, or a list/tuple of names (one flag per name).
            if isinstance(self.kernels_used, str):
                options.append(f"-kernels-used={self.kernels_used}")
            elif isinstance(self.kernels_used, (list, tuple)):
                for kernel in self.kernels_used:
                    options.append(f"-kernels-used={kernel}")
        if self.variables_used is not None:
            # A single name, or a list/tuple of names (one flag per name).
            if isinstance(self.variables_used, str):
                options.append(f"-variables-used={self.variables_used}")
            elif isinstance(self.variables_used, (list, tuple)):
                for variable in self.variables_used:
                    options.append(f"-variables-used={variable}")
        # Truthiness: optimize_unused_variables=False must not enable the flag.
        if self.optimize_unused_variables:
            options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            if isinstance(self.ptxas_options, str):
                options.append(f"-Xptxas={self.ptxas_options}")
            elif is_sequence(self.ptxas_options):
                for opt in self.ptxas_options:
                    options.append(f"-Xptxas={opt}")
        if self.split_compile is not None:
            options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            options.append("-no-cache")

        if as_bytes:
            return [o.encode() for o in options]
        else:
            return options

    def _prepare_driver_options(self) -> tuple[list, list]:
        """Build parallel (values, keys) lists for ``cuLinkCreate``.

        The first four entries are always the info-log buffer, its size, the
        error-log buffer and its size; :class:`Linker` reads the logs back from
        value indices 0 and 2.

        Raises
        ------
        ValueError
            For options the driver API cannot express.

        Options the driver path does not forward (ftz, prec_div, prec_sqrt, fma,
        kernels_used, variables_used, optimize_unused_variables) emit a
        ``DeprecationWarning`` instead.
        """
        formatted_options = []
        option_keys = []

        # allocate a fixed-sized buffer for each info/error log
        size = 4194304
        formatted_options.extend((bytearray(size), size, bytearray(size), size))
        option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_90" / "compute_90" -> CU_TARGET_COMPUTE_90
            arch = self.arch.split("_")[-1].upper()
            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            formatted_options.append(self.max_register_count)
            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        # Only reject an actual request for timing; time=False is a no-op.
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            formatted_options.append(self.optimization_level)
            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

        return formatted_options, option_keys

    def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
        """Convert linker options to bytes format for the nvjitlink backend.

        Parameters
        ----------
        backend : str, optional
            The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink".

        Returns
        -------
        list[bytes]
            List of option strings encoded as bytes.

        Raises
        ------
        ValueError
            If an unsupported backend is specified.
        RuntimeError
            If nvJitLink backend is not available.
        """
        backend = backend.lower()
        if backend != "nvjitlink":
            raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
        if not _use_nvjitlink_backend:
            raise RuntimeError("nvJitLink backend is not available")
        return self._prepare_nvjitlink_options(as_bytes=True)
438# =============================================================================
439# Private implementation: cdef inline helpers
440# =============================================================================
cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1:
    """Initialize a Linker instance.

    Creates the backend handle (nvJitLink or driver cuLink) with the
    translated options, then adds every object code to it.

    Returns 0 on success; on failure a Python exception is raised (signalled
    to Cython callers by the ``except -1`` return value).

    Raises
    ------
    ValueError
        If no object codes are given.
    CUDAError
        If handle creation fails (driver path: annotated with the error log).
    """
    if len(object_codes) == 0:
        raise ValueError("At least one ObjectCode object must be provided")

    cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink
    cdef cydriver.CUlinkState c_raw_culink
    cdef Py_ssize_t c_num_opts, i
    cdef vector[const_char_ptr] c_str_opts
    cdef vector[cydriver.CUjit_option] c_jit_keys
    cdef vector[void_ptr] c_jit_values

    # Validates `options` (or presumably builds a default LinkerOptions when it
    # is None — TODO confirm check_or_create_options semantics). Any
    # LinkerOptions instance has run __post_init__ -> _lazy_init(), so the
    # _use_nvjitlink_backend flag read below has been decided by now.
    self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")

    if _use_nvjitlink_backend:
        self._use_nvjitlink = True
        options_bytes = options._prepare_nvjitlink_options(as_bytes=True)
        c_num_opts = len(options_bytes)
        c_str_opts.resize(c_num_opts)
        # c_str_opts holds borrowed pointers into the bytes objects in
        # options_bytes, which stays alive until nvJitLinkCreate returns.
        for i in range(c_num_opts):
            c_str_opts[i] = <const char*>(<bytes>options_bytes[i])
        with nogil:
            HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate(
                &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data()))
        self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink)
    else:
        self._use_nvjitlink = False
        formatted_options, option_keys = options._prepare_driver_options()
        # Keep the formatted_options list alive: it contains bytearrays that
        # the driver writes into via raw pointers during linking operations.
        self._drv_log_bufs = formatted_options
        c_num_opts = len(option_keys)
        c_jit_keys.resize(c_num_opts)
        c_jit_values.resize(c_num_opts)
        for i in range(c_num_opts):
            c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i]
            val = formatted_options[i]
            # Buffers are passed by address; every other value (sizes, flags,
            # enum members) is passed by value, packed into the void* slot.
            if isinstance(val, bytearray):
                c_jit_values[i] = <void*>PyByteArray_AS_STRING(val)
            else:
                c_jit_values[i] = <void*><intptr_t>int(val)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkCreate(
                    <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink))
        except CUDAError as e:
            # Append whatever the driver wrote into the error-log buffer.
            Linker_annotate_error_log(self, e)
            raise
        self._culink_handle = create_culink_handle(c_raw_culink)

    # Every input must be an ObjectCode; each is added to the new handle.
    for code in object_codes:
        assert_type(code, ObjectCode)
        Linker_add_code_object(self, code)
    return 0
cdef inline void Linker_add_code_object(Linker self, object object_code) except *:
    """Add a single ObjectCode to the linker.

    ``object_code.code`` may be ``bytes`` (in-memory image, added with
    nvJitLinkAddData / cuLinkAddData) or ``str`` (treated as a file path,
    added with nvJitLinkAddFile / cuLinkAddFile).

    Raises
    ------
    ValueError
        If ``object_code.code_type`` has no mapping for the active backend
        (e.g. "ltoir" is only present in the nvJitLink table).
    TypeError
        If ``object_code.code`` is neither bytes nor str.
    CUDAError
        On driver-backend failures, annotated with the linker error log.
    """
    data = object_code.code
    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef cynvjitlink.nvJitLinkInputType c_nv_input_type
    cdef cydriver.CUjitInputType c_drv_input_type
    cdef const char* c_data_ptr
    cdef size_t c_data_size
    cdef const char* c_name_ptr
    cdef const char* c_file_ptr

    # name_bytes owns the storage c_name_ptr points into; it stays alive for
    # the rest of this function.
    name_bytes = f"{object_code.name}".encode()
    c_name_ptr = <const char*>name_bytes

    # Tables filled by _lazy_init(): code_type string -> backend enum value (int).
    input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types
    py_input_type = input_types.get(object_code.code_type)
    if py_input_type is None:
        raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}")

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type
        if isinstance(data, bytes):
            # Borrowed pointer into `data`, which outlives the nogil call.
            c_data_ptr = <const char*>(<bytes>data)
            c_data_size = len(data)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData(
                    c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr))
        elif isinstance(data, str):
            # A str is a path on disk; file_bytes owns the encoded path.
            file_bytes = data.encode()
            c_file_ptr = <const char*>file_bytes
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile(
                    c_nvjitlink_h, c_nv_input_type, c_file_ptr))
        else:
            raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
    else:
        c_culink_state = as_cu(self._culink_handle)
        c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type
        try:
            if isinstance(data, bytes):
                c_data_ptr = <const char*>(<bytes>data)
                c_data_size = len(data)
                with nogil:
                    # No per-input JIT options: count 0, NULL key/value arrays.
                    HANDLE_RETURN(cydriver.cuLinkAddData(
                        c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr,
                        0, NULL, NULL))
            elif isinstance(data, str):
                file_bytes = data.encode()
                c_file_ptr = <const char*>file_bytes
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddFile(
                        c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL))
            else:
                raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
        except CUDAError as e:
            # Only driver errors are annotated; the TypeError above passes through.
            Linker_annotate_error_log(self, e)
            raise
cdef inline object Linker_link(Linker self, str target_type):
    """Complete linking and return the result as ObjectCode.

    After a successful link the info/error logs are decoded and cached on
    the Linker, and the driver log buffers are released.

    Raises
    ------
    ValueError
        If ``target_type`` is not "cubin" or "ptx".
    CUDAError
        If the backend reports a failure (driver path: annotated with the
        error log).
    """
    if target_type not in ("cubin", "ptx"):
        raise ValueError(f"Unsupported target type: {target_type}")

    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef size_t c_output_size = 0
    cdef char* c_code_ptr
    cdef void* c_cubin_out = NULL

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h))
        # Query the output size, allocate a bytearray of that size, then have
        # nvJitLink copy the linked output into it.
        if target_type == "cubin":
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr))
        else:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr))
    else:
        c_culink_state = as_cu(self._culink_handle)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size))
        except CUDAError as e:
            Linker_annotate_error_log(self, e)
            raise
        # Slicing the char* copies the output image into a new bytes object;
        # per the CUDA driver docs the image itself is owned by the link state.
        # NOTE(review): cuLinkComplete always yields a cubin image, yet the
        # result below is labelled with the requested target_type even when it
        # is "ptx" — confirm whether link("ptx") should be rejected on this backend.
        code = (<char*>c_cubin_out)[:c_output_size]

    # Linking is complete; cache the decoded log strings and release
    # the driver's raw bytearray buffers (no longer written to).
    # Order matters: the getters still read _drv_log_bufs while the cached
    # values are None, so the buffers are dropped only afterwards.
    self._info_log = self.get_info_log()
    self._error_log = self.get_error_log()
    self._drv_log_bufs = None

    return ObjectCode._init(bytes(code), target_type, name=self._options.name)
cdef inline void Linker_annotate_error_log(Linker self, object e):
    """Append the linker's error log (if non-empty) to a CUDAError's message.

    The first element of ``e.args`` gets the log appended; any remaining
    args are kept unchanged.
    """
    log_text = self.get_error_log()
    if not log_text:
        return
    annotated = e.args[0] + f"\nLinker error log: {log_text}"
    e.args = (annotated,) + tuple(e.args[1:])
616# =============================================================================
617# Private implementation: module-level state and initialization
618# =============================================================================
# TODO: revisit this treatment for py313t builds
# (NOTE(review): the globals below are written by _decide_nvjitlink_or_driver()
# and _lazy_init() without any locking — presumably the concern for
# free-threaded builds.)

# The cuda.bindings driver module; assigned only when the driver (cuLink*)
# backend is chosen and read by LinkerOptions._prepare_driver_options.
_driver = None  # populated if nvJitLink cannot be used
# Guards _lazy_init() so backend selection and table setup run only once.
_inited = False
# Tri-state: None = not decided yet, True = nvJitLink, False = driver fallback.
_use_nvjitlink_backend = None  # set by _decide_nvjitlink_or_driver()

# Input type mappings populated by _lazy_init() with C-level enum ints.
# Only the table for the selected backend is filled; the other stays None.
_nvjitlink_input_types = None
_driver_input_types = None
630def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
631 # This condition is equivalent to testing for version >= 12.3
632 return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
    """Return True if falling back to the cuLink* driver APIs.

    The decision is made once and remembered in ``_use_nvjitlink_backend``.
    nvJitLink is used when ``cuda.bindings.nvjitlink`` imports, its probe
    succeeds, and the library exports the 12.3+ version symbol. Otherwise a
    ``RuntimeWarning`` is emitted and ``_driver`` is set to the driver bindings.
    """
    global _driver, _use_nvjitlink_backend
    if _use_nvjitlink_backend is not None:
        # Already decided on an earlier call.
        return not _use_nvjitlink_backend

    fallback_tail = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    # Calling version() in the probe forces the nvJitLink runtime to load, so
    # a missing shared library shows up here as an unavailable module.
    nvjitlink_module = _optional_cuda_import(
        "cuda.bindings.nvjitlink",
        probe_function=lambda module: module.version(),  # probe triggers nvJitLink runtime load
    )
    if nvjitlink_module is not None:
        from cuda.bindings._internal import nvjitlink

        if _nvjitlink_has_version_symbol(nvjitlink):
            _use_nvjitlink_backend = True
            return False  # Use nvjitlink
        library_pattern = "nvJitLink*.dll" if sys.platform == "win32" else "libnvJitLink.so*"
        message = (
            f"{library_pattern} is too old (<12.3)."
            f" Therefore cuda.bindings.nvjitlink is not usable and {fallback_tail} nvJitLink."
        )
    else:
        message = f"cuda.bindings.nvjitlink is not available, therefore {fallback_tail} cuda-bindings."

    warn(message, stacklevel=2, category=RuntimeWarning)
    _use_nvjitlink_backend = False
    _driver = driver
    return True
def _lazy_init():
    """Choose the linking backend once and build its input-type table.

    Later calls do nothing. The table maps ObjectCode ``code_type`` strings
    to the selected backend's C enum values, stored as plain ints; only the
    table for the chosen backend is populated.
    """
    global _inited, _nvjitlink_input_types, _driver_input_types
    if not _inited:
        _decide_nvjitlink_or_driver()
        if not _use_nvjitlink_backend:
            # The driver path has no LTO-IR input type.
            _driver_input_types = {
                "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
                "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
                "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
                "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
                "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
            }
        else:
            _nvjitlink_input_types = {
                "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
                "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
                "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
                "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
                "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
                "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
            }
        _inited = True