Coverage for cuda / core / _linker.pyx: 63.81%
362 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
"""Linking machinery for combining object codes.

This module provides :class:`Linker` for linking one or more
:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for
configuration.
"""
11from __future__ import annotations
13from cpython.bytearray cimport PyByteArray_AS_STRING
14from libc.stdint cimport intptr_t, uint32_t
15from libcpp.vector cimport vector
16from cuda.bindings cimport cydriver
17from cuda.bindings cimport cynvjitlink
19from ._resource_handles cimport (
20 as_cu,
21 as_py,
22 create_culink_handle,
23 create_nvjitlink_handle,
24)
25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK
27import sys
28from dataclasses import dataclass
29from typing import Union
30from warnings import warn
32from cuda.pathfinder._optional_cuda_import import _optional_cuda_import
33from cuda.core._device import Device
34from cuda.core._module import ObjectCode
35from cuda.core._utils.clear_error_support import assert_type
36from cuda.core._utils.cuda_utils import (
37 CUDAError,
38 check_or_create_options,
39 driver,
40 handle_return,
41 is_sequence,
42)
# Single-token aliases for pointer types so they can be used as C++ template
# arguments (this module declares ``vector[const_char_ptr]`` and
# ``vector[void_ptr]`` locals).
ctypedef const char* const_char_ptr
ctypedef void* void_ptr

# Names exported by ``from ... import *``.
__all__ = ["Linker", "LinkerOptions"]

# Annotation for ``Linker.handle``: which member applies depends on the
# backend the Linker instance is using.
LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"]
52# =============================================================================
53# Principal class
54# =============================================================================
cdef class Linker:
    """Link one or more object codes into a single :class:`~cuda.core.ObjectCode`.

    The class is a single front end over two possible linking backends:
    nvJitLink, or the cuLink* APIs of the CUDA driver.

    Parameters
    ----------
    object_codes : :class:`~cuda.core.ObjectCode`
        One or more ObjectCode objects to be linked.
    options : :class:`LinkerOptions`, optional
        Options for the linker. If not provided, default options will be used.
    """

    def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
        Linker_init(self, object_codes, options)

    def link(self, target_type) -> ObjectCode:
        """Link every object code given at construction into one output.

        Parameters
        ----------
        target_type : str
            The type of the target output. Must be either "cubin" or "ptx".

        Returns
        -------
        :class:`~cuda.core.ObjectCode`
            The linked object code of the specified target type.

        .. note::

            Ensure that input object codes were compiled with appropriate
            flags for linking (e.g., relocatable device code enabled).
        """
        return Linker_link(self, target_type)

    def get_error_log(self) -> str:
        """Return the error log produced by the linker.

        Returns
        -------
        str
            The error log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_dst

        # link() stores the decoded log on the instance; reuse it when present.
        if self._error_log is not None:
            return self._error_log

        if not self._use_nvjitlink:
            # Driver backend: the error-log buffer sits at index 2 of the
            # option values kept alive on the instance.
            return (<bytearray>self._drv_log_bufs[2]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')

        c_handle = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetErrorLogSize(c_handle, &c_nbytes)
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_dst = <char*>(<bytearray>buf)
            cynvjitlink.nvJitLinkGetErrorLog(c_handle, c_dst)
        return buf.decode("utf-8", errors="backslashreplace")

    def get_info_log(self) -> str:
        """Return the info log produced by the linker.

        Returns
        -------
        str
            The info log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_dst

        # link() stores the decoded log on the instance; reuse it when present.
        if self._info_log is not None:
            return self._info_log

        if not self._use_nvjitlink:
            # Driver backend: the info-log buffer sits at index 0 of the
            # option values kept alive on the instance.
            return (<bytearray>self._drv_log_bufs[0]).decode(
                "utf-8", errors="backslashreplace").rstrip('\x00')

        c_handle = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetInfoLogSize(c_handle, &c_nbytes)
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_dst = <char*>(<bytearray>buf)
            cynvjitlink.nvJitLinkGetInfoLog(c_handle, c_dst)
        return buf.decode("utf-8", errors="backslashreplace")

    def close(self):
        """Destroy this linker."""
        if not self._use_nvjitlink:
            self._culink_handle.reset()
            return
        self._nvjitlink_handle.reset()

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

            The type of the returned object depends on the backend.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Linker.handle)``.
        """
        if self._use_nvjitlink:
            return as_py(self._nvjitlink_handle)
        return as_py(self._culink_handle)

    @property
    def backend(self) -> str:
        """Return this Linker instance's underlying backend."""
        if self._use_nvjitlink:
            return "nvJitLink"
        return "driver"
177# =============================================================================
178# Supporting classes
179# =============================================================================
@dataclass
class LinkerOptions:
    """Customizable options for configuring :class:`Linker`.

    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated `ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [str | tuple[str] | list[str]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [str | tuple[str] | list[str]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [str | tuple[str] | list[str]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: str | tuple[str] | list[str] | None = None
    variables_used: str | tuple[str] | list[str] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: str | tuple[str] | list[str] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        # Backend selection happens once per process, the first time any
        # options object is created.
        _lazy_init()
        # NOTE(review): ``name`` is annotated as optional, but ``None`` raises
        # AttributeError here -- confirm whether ``name=None`` should be supported.
        self._name = self.name.encode()

    def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]:
        """Translate the set options into nvJitLink command-line style flags.

        Parameters
        ----------
        as_bytes : bool, optional
            When True, every flag is returned encoded as ``bytes``; otherwise
            the flags are returned as ``str``.

        Returns
        -------
        list[bytes] | list[str]
            One entry per flag. ``-arch=...`` always comes first; when ``arch``
            is unset it is derived from the current device's compute capability.
        """
        options = []

        if self.arch is not None:
            options.append(f"-arch={self.arch}")
        else:
            options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            options.append(f"-maxrregcount={self.max_register_count}")
        # ``-time`` is a bare switch: an explicit ``time=False`` must not enable it.
        if self.time:
            options.append("-time")
        if self.verbose:
            options.append("-verbose")
        if self.link_time_optimization:
            options.append("-lto")
        if self.ptx:
            options.append("-ptx")
        if self.optimization_level is not None:
            options.append(f"-O{self.optimization_level}")
        if self.debug:
            options.append("-g")
        if self.lineinfo:
            options.append("-lineinfo")
        # The following four flags take an explicit true/false value, so they
        # are emitted whenever the user set them at all.
        if self.ftz is not None:
            options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            options.append(f"-fma={'true' if self.fma else 'false'}")
        if self.kernels_used is not None:
            if isinstance(self.kernels_used, str):
                options.append(f"-kernels-used={self.kernels_used}")
            # Use the same sequence test as ptxas_options so that tuples
            # (allowed by the annotation) are not silently dropped.
            elif is_sequence(self.kernels_used):
                for kernel in self.kernels_used:
                    options.append(f"-kernels-used={kernel}")
        if self.variables_used is not None:
            if isinstance(self.variables_used, str):
                options.append(f"-variables-used={self.variables_used}")
            elif is_sequence(self.variables_used):
                for variable in self.variables_used:
                    options.append(f"-variables-used={variable}")
        # Bare switch: an explicit ``False`` must not enable it.
        if self.optimize_unused_variables:
            options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            if isinstance(self.ptxas_options, str):
                options.append(f"-Xptxas={self.ptxas_options}")
            elif is_sequence(self.ptxas_options):
                for opt in self.ptxas_options:
                    options.append(f"-Xptxas={opt}")
        if self.split_compile is not None:
            options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            options.append("-no-cache")

        if as_bytes:
            return [o.encode() for o in options]
        else:
            return options

    def _prepare_driver_options(self) -> tuple[list, list]:
        """Translate the set options into parallel value/key lists for the driver linker.

        Returns
        -------
        tuple[list, list]
            ``(formatted_options, option_keys)``. The two lists have equal
            length; ``formatted_options[i]`` is the value for the
            ``CUjit_option`` key ``option_keys[i]``. The first four entries are
            always the info-log buffer, its size, the error-log buffer and its
            size, in that order.

        Raises
        ------
        ValueError
            If an option that the driver backend cannot express is enabled
            (``time``, ``ptx``, ``ptxas_options``, ``split_compile``,
            ``split_compile_extended``).
        """
        formatted_options = []
        option_keys = []

        # allocate a fixed-sized buffer for each info/error log
        size = 4194304
        formatted_options.extend((bytearray(size), size, bytearray(size), size))
        option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_90" / "compute_90" -> "90" -> CU_TARGET_COMPUTE_90
            arch = self.arch.split("_")[-1].upper()
            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            formatted_options.append(self.max_register_count)
            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        # Only reject the option when it is actually enabled, matching ``ptx`` below.
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            formatted_options.append(self.optimization_level)
            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        # The options below are not forwarded to the driver; the user only
        # gets a DeprecationWarning when they were set.
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

        return formatted_options, option_keys

    def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
        """Convert linker options to bytes format for the nvjitlink backend.

        Parameters
        ----------
        backend : str, optional
            The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink".

        Returns
        -------
        list[bytes]
            List of option strings encoded as bytes.

        Raises
        ------
        ValueError
            If an unsupported backend is specified.
        RuntimeError
            If nvJitLink backend is not available.
        """
        backend = backend.lower()
        if backend != "nvjitlink":
            raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
        if not _use_nvjitlink_backend:
            raise RuntimeError("nvJitLink backend is not available")
        return self._prepare_nvjitlink_options(as_bytes=True)
439# =============================================================================
440# Private implementation: cdef inline helpers
441# =============================================================================
cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1:
    """Create the backend link handle for ``self`` and add every input to it.

    ``object_codes`` must be non-empty (``ValueError`` otherwise). ``options``
    is passed through ``check_or_create_options`` together with
    ``LinkerOptions`` and the result is stored on ``self._options``. The
    backend is picked from the module-level ``_use_nvjitlink_backend`` flag.

    Returns 0 on success; ``except -1`` lets Python exceptions propagate to
    the caller.
    """
    if len(object_codes) == 0:
        raise ValueError("At least one ObjectCode object must be provided")

    cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink
    cdef cydriver.CUlinkState c_raw_culink
    cdef Py_ssize_t c_num_opts, i
    cdef vector[const_char_ptr] c_str_opts
    cdef vector[cydriver.CUjit_option] c_jit_keys
    cdef vector[void_ptr] c_jit_values

    self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")

    if _use_nvjitlink_backend:
        self._use_nvjitlink = True
        # ``options_bytes`` must stay referenced until nvJitLinkCreate returns:
        # ``c_str_opts`` only stores raw pointers into its bytes objects.
        options_bytes = options._prepare_nvjitlink_options(as_bytes=True)
        c_num_opts = len(options_bytes)
        c_str_opts.resize(c_num_opts)
        for i in range(c_num_opts):
            c_str_opts[i] = <const char*>(<bytes>options_bytes[i])
        with nogil:
            # No handle exists yet, so NULL is passed as the handle argument
            # of the error-checking helper.
            HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate(
                &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data()))
        # Wrap the raw handle in the managed handle stored on the instance.
        self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink)
    else:
        self._use_nvjitlink = False
        formatted_options, option_keys = options._prepare_driver_options()
        # Keep the formatted_options list alive: it contains bytearrays that
        # the driver writes into via raw pointers during linking operations.
        self._drv_log_bufs = formatted_options
        c_num_opts = len(option_keys)
        c_jit_keys.resize(c_num_opts)
        c_jit_values.resize(c_num_opts)
        for i in range(c_num_opts):
            c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i]
            val = formatted_options[i]
            if isinstance(val, bytearray):
                # Log buffers are passed by address.
                c_jit_values[i] = <void*>PyByteArray_AS_STRING(val)
            else:
                # Every other option value is an integer carried inside the
                # void* slot itself.
                c_jit_values[i] = <void*><intptr_t>int(val)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkCreate(
                    <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink))
        except CUDAError as e:
            # Attach whatever the driver wrote to the error-log buffer before re-raising.
            Linker_annotate_error_log(self, e)
            raise
        self._culink_handle = create_culink_handle(c_raw_culink)

    # Inputs are type-checked and added one at a time, in the order given.
    for code in object_codes:
        assert_type(code, ObjectCode)
        Linker_add_code_object(self, code)
    return 0
cdef inline void Linker_add_code_object(Linker self, object object_code) except *:
    """Add a single ObjectCode to the linker.

    ``object_code.code`` is handled in two forms: ``bytes`` is added as
    in-memory data (together with ``object_code.name``), and ``str`` is
    encoded and added through the backend's add-file call. Any other type
    raises ``TypeError``. ``object_code.code_type`` must be a key of the active
    backend's input-type table, otherwise ``ValueError`` is raised.
    """
    data = object_code.code
    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef cynvjitlink.nvJitLinkInputType c_nv_input_type
    cdef cydriver.CUjitInputType c_drv_input_type
    cdef const char* c_data_ptr
    cdef size_t c_data_size
    cdef const char* c_name_ptr
    cdef const char* c_file_ptr

    # ``name_bytes`` stays referenced for the whole function so that
    # ``c_name_ptr`` remains valid inside the nogil sections below.
    name_bytes = f"{object_code.name}".encode()
    c_name_ptr = <const char*>name_bytes

    # Map the string code type to the backend's integer enum value.
    input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types
    py_input_type = input_types.get(object_code.code_type)
    if py_input_type is None:
        raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}")

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type
        if isinstance(data, bytes):
            c_data_ptr = <const char*>(<bytes>data)
            c_data_size = len(data)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData(
                    c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr))
        elif isinstance(data, str):
            # ``file_bytes`` keeps the encoded string alive while ``c_file_ptr`` is in use.
            file_bytes = data.encode()
            c_file_ptr = <const char*>file_bytes
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile(
                    c_nvjitlink_h, c_nv_input_type, c_file_ptr))
        else:
            raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
    else:
        c_culink_state = as_cu(self._culink_handle)
        c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type
        try:
            if isinstance(data, bytes):
                c_data_ptr = <const char*>(<bytes>data)
                c_data_size = len(data)
                with nogil:
                    # No per-input JIT options are passed (0, NULL, NULL).
                    HANDLE_RETURN(cydriver.cuLinkAddData(
                        c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr,
                        0, NULL, NULL))
            elif isinstance(data, str):
                # ``file_bytes`` keeps the encoded string alive while ``c_file_ptr`` is in use.
                file_bytes = data.encode()
                c_file_ptr = <const char*>file_bytes
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddFile(
                        c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL))
            else:
                raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
        except CUDAError as e:
            # Attach the driver's error log to the exception before re-raising.
            Linker_annotate_error_log(self, e)
            raise
cdef inline object Linker_link(Linker self, str target_type):
    """Complete linking and return the result as ObjectCode.

    ``target_type`` must be ``"cubin"`` or ``"ptx"`` (``ValueError``
    otherwise). After a successful link the info and error logs are decoded
    and cached on ``self``, and the reference to the driver log buffers is
    dropped. The returned ObjectCode is built from a ``bytes`` copy of the
    linked image and carries ``self._options.name``.
    """
    if target_type not in ("cubin", "ptx"):
        raise ValueError(f"Unsupported target type: {target_type}")

    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef size_t c_output_size = 0
    cdef char* c_code_ptr
    cdef void* c_cubin_out = NULL

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h))
        # Two-step retrieval: query the size, allocate a bytearray of that
        # size, then have nvJitLink fill it in place.
        if target_type == "cubin":
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr))
        else:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr))
    else:
        # NOTE(review): this branch does not consult ``target_type``; the
        # driver output is labelled with whatever was requested -- confirm
        # that "ptx" is meaningful for the driver backend.
        c_culink_state = as_cu(self._culink_handle)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size))
        except CUDAError as e:
            Linker_annotate_error_log(self, e)
            raise
        # Slicing the char* copies ``c_output_size`` bytes into a Python bytes object.
        code = (<char*>c_cubin_out)[:c_output_size]

    # Linking is complete; cache the decoded log strings and release
    # the driver's raw bytearray buffers (no longer written to).
    # The info log is read before the error log, while ``_drv_log_bufs``
    # is still set.
    self._info_log = self.get_info_log()
    self._error_log = self.get_error_log()
    self._drv_log_bufs = None

    return ObjectCode._init(bytes(code), target_type, name=self._options.name)
cdef inline void Linker_annotate_error_log(Linker self, object e):
    """Append the linker's error log to the message of exception ``e``.

    ``e.args`` is left untouched when the log is empty. Otherwise the first
    element of ``e.args`` gets the log appended and the remaining elements are
    kept as they were.
    """
    log_text = self.get_error_log()
    if not log_text:
        return
    annotated = e.args[0] + f"\nLinker error log: {log_text}"
    e.args = (annotated, *e.args[1:])
617# =============================================================================
618# Private implementation: module-level state and initialization
619# =============================================================================
# TODO: revisit this treatment for py313t builds
# Module-level state shared by every Linker / LinkerOptions in the process.
_driver = None  # populated if nvJitLink cannot be used
_driver_ver = None  # (major, minor) tuple once _decide_nvjitlink_or_driver() has run
_inited = False  # set to True at the end of _lazy_init()
_use_nvjitlink_backend = False  # set by _decide_nvjitlink_or_driver()

# Input type mappings populated by _lazy_init() with C-level enum ints.
# Only the mapping for the selected backend is filled in; the other stays None.
_nvjitlink_input_types = None
_driver_input_types = None
def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
    """Report whether ``nvjitlink`` resolves the ``__nvJitLinkVersion`` symbol."""
    # A resolvable symbol is equivalent to an nvJitLink version >= 12.3.
    version_fn_ptr = nvjitlink._inspect_function_pointer("__nvJitLinkVersion")
    return True if version_fn_ptr else False
637# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
    """Choose the linking backend; return True when falling back to the cuLink* driver APIs.

    The decision is made once. Later calls return the answer derived from the
    module-level ``_use_nvjitlink_backend`` flag without probing again. When
    the driver fallback is chosen, a ``RuntimeWarning`` explaining why is
    issued.
    """
    global _driver_ver, _driver, _use_nvjitlink_backend

    # A populated driver version means the decision has already been made.
    if _driver_ver is not None:
        return not _use_nvjitlink_backend

    # Split the packed driver version into (major, minor).
    _driver_ver = handle_return(driver.cuDriverGetVersion())
    major, remainder = divmod(_driver_ver, 1000)
    _driver_ver = (major, remainder // 10)

    warn_txt_common = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    def _probe(module):
        # probe triggers nvJitLink runtime load
        return module.version()

    nvjitlink_module = _optional_cuda_import(
        "cuda.bindings.nvjitlink",
        probe_function=_probe,
    )
    if nvjitlink_module is None:
        warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
    else:
        from cuda.bindings._internal import nvjitlink

        if _nvjitlink_has_version_symbol(nvjitlink):
            # nvJitLink is usable: record the choice and skip the warning.
            _use_nvjitlink_backend = True
            return False

        lib_pattern = 'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'
        warn_txt = (
            f"{lib_pattern} is too old (<12.3)."
            f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
        )

    warn(warn_txt, stacklevel=2, category=RuntimeWarning)
    _driver = driver
    return True
def _lazy_init():
    """Run the one-time module setup: pick the backend and build its input-type table.

    Later calls return immediately once ``_inited`` is set. Only the table for
    the selected backend is populated, mapping ObjectCode code-type strings to
    the backend's enum values as plain ints.
    """
    global _inited, _nvjitlink_input_types, _driver_input_types
    if _inited:
        return

    _decide_nvjitlink_or_driver()
    if not _use_nvjitlink_backend:
        _driver_input_types = {
            "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
            "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
            "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
            "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
            "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
        }
    else:
        _nvjitlink_input_types = {
            "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
            "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
            "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
            "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
            "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
            "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
        }
    _inited = True