Coverage for cuda / core / _linker.pyx: 63.54%
362 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4"""Linking machinery for combining object codes.
6This module provides :class:`Linker` for linking one or more
7:class:`~cuda.core.ObjectCode` objects, with :class:`LinkerOptions` for
8configuration.
9"""
11from __future__ import annotations
13from cpython.bytearray cimport PyByteArray_AS_STRING
14from libc.stdint cimport intptr_t, uint32_t
15from libcpp.vector cimport vector
16from cuda.bindings cimport cydriver
17from cuda.bindings cimport cynvjitlink
19from ._resource_handles cimport (
20 as_cu,
21 as_py,
22 create_culink_handle,
23 create_nvjitlink_handle,
24)
25from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, HANDLE_RETURN_NVJITLINK
27import sys
28from dataclasses import dataclass
29from typing import Union
30from warnings import warn
32from cuda.pathfinder import optional_cuda_import
33from cuda.core._device import Device
34from cuda.core._module import ObjectCode
35from cuda.core._utils.clear_error_support import assert_type
36from cuda.core._utils.cuda_utils import (
37 CUDAError,
38 check_or_create_options,
39 driver,
40 handle_return,
41 is_sequence,
42)
# Single-identifier aliases for the pointer types; this module uses them as
# element types of libcpp vectors (vector[const_char_ptr], vector[void_ptr]).
ctypedef const char* const_char_ptr
ctypedef void* void_ptr

# Names exported by ``from ... import *``.
__all__ = ["Linker", "LinkerOptions"]

# Annotation for Linker.handle: which of the two handle kinds is returned
# depends on the backend the Linker instance is using.
LinkerHandleT = Union["cuda.bindings.nvjitlink.nvJitLinkHandle", "cuda.bindings.driver.CUlinkState"]
52# =============================================================================
53# Principal class
54# =============================================================================
cdef class Linker:
    """Link one or more object codes into a single :class:`~cuda.core.ObjectCode`.

    A single interface is exposed on top of several underlying linker
    libraries (such as nvJitLink or cuLink* from the CUDA driver).

    Parameters
    ----------
    object_codes : :class:`~cuda.core.ObjectCode`
        One or more ObjectCode objects to be linked.
    options : :class:`LinkerOptions`, optional
        Options for the linker. If not provided, default options will be used.
    """

    def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
        # All of the setup work lives in the module-level cdef helper.
        Linker_init(self, object_codes, options)

    def link(self, target_type) -> ObjectCode:
        """Complete the link and return the result in the requested format.

        Parameters
        ----------
        target_type : str
            The type of the target output. Must be either "cubin" or "ptx".

        Returns
        -------
        :class:`~cuda.core.ObjectCode`
            The linked object code of the specified target type.

        .. note::

            Ensure that input object codes were compiled with appropriate
            flags for linking (e.g., relocatable device code enabled).
        """
        linked = Linker_link(self, target_type)
        return linked

    def get_error_log(self) -> str:
        """Return the error log produced by the linker.

        Returns
        -------
        str
            The error log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_buf

        # Once link() has run, the decoded text is served from the cache.
        cached = self._error_log
        if cached is not None:
            return cached

        if not self._use_nvjitlink:
            # Driver backend: the error log buffer sits at index 2 of the
            # option values; it is NUL-padded, so trailing NULs are stripped.
            text = (<bytearray>self._drv_log_bufs[2]).decode(
                "utf-8", errors="backslashreplace")
            return text.rstrip('\x00')

        # nvJitLink backend: query the size, then fill a buffer of that size.
        c_handle = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetErrorLogSize(c_handle, &c_nbytes)
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_buf = <char*>(<bytearray>buf)
            cynvjitlink.nvJitLinkGetErrorLog(c_handle, c_buf)
        return buf.decode("utf-8", errors="backslashreplace")

    def get_info_log(self) -> str:
        """Return the info log produced by the linker.

        Returns
        -------
        str
            The info log.
        """
        cdef cynvjitlink.nvJitLinkHandle c_handle
        cdef size_t c_nbytes = 0
        cdef char* c_buf

        # Once link() has run, the decoded text is served from the cache.
        cached = self._info_log
        if cached is not None:
            return cached

        if not self._use_nvjitlink:
            # Driver backend: the info log buffer sits at index 0 of the
            # option values; it is NUL-padded, so trailing NULs are stripped.
            text = (<bytearray>self._drv_log_bufs[0]).decode(
                "utf-8", errors="backslashreplace")
            return text.rstrip('\x00')

        # nvJitLink backend: query the size, then fill a buffer of that size.
        c_handle = as_cu(self._nvjitlink_handle)
        cynvjitlink.nvJitLinkGetInfoLogSize(c_handle, &c_nbytes)
        buf = bytearray(c_nbytes)
        if c_nbytes > 0:
            c_buf = <char*>(<bytearray>buf)
            cynvjitlink.nvJitLinkGetInfoLog(c_handle, c_buf)
        return buf.decode("utf-8", errors="backslashreplace")

    def close(self):
        """Destroy this linker."""
        # Only the handle that belongs to the active backend is reset.
        if not self._use_nvjitlink:
            self._culink_handle.reset()
            return
        self._nvjitlink_handle.reset()

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

            The type of the returned object depends on the backend.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Linker.handle)``.
        """
        if not self._use_nvjitlink:
            return as_py(self._culink_handle)
        return as_py(self._nvjitlink_handle)

    @property
    def backend(self) -> str:
        """Return this Linker instance's underlying backend."""
        if self._use_nvjitlink:
            return "nvJitLink"
        return "driver"
177# =============================================================================
178# Supporting classes
179# =============================================================================
@dataclass
class LinkerOptions:
    """Customizable options for configuring :class:`Linker`.

    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated `ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [str | tuple[str] | list[str]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [str | tuple[str] | list[str]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [str | tuple[str] | list[str]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: str | tuple[str] | list[str] | None = None
    variables_used: str | tuple[str] | list[str] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: str | tuple[str] | list[str] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        """Run the one-time backend selection and cache the encoded name."""
        _lazy_init()
        # NOTE(review): ``name`` is annotated as optional, but ``name=None``
        # raises AttributeError here -- confirm whether None should be allowed.
        self._name = self.name.encode()

    def _prepare_nvjitlink_options(self, as_bytes: bool = False) -> list[bytes] | list[str]:
        """Translate the options into nvJitLink command-line style flags.

        Parameters
        ----------
        as_bytes : bool, optional
            When True, every flag is encoded to ``bytes``; otherwise the
            flags are returned as ``str``.

        Returns
        -------
        list[bytes] | list[str]
            One entry per flag, in a fixed order starting with ``-arch=``.
        """
        options = []

        if self.arch is not None:
            options.append(f"-arch={self.arch}")
        else:
            # No explicit arch: derive sm_<CC> from the current device.
            options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            options.append(f"-maxrregcount={self.max_register_count}")
        # Presence-only switches are emitted only when truthy, so an explicit
        # ``False`` means "off" exactly like leaving the option unset.
        if self.time:
            options.append("-time")
        if self.verbose:
            options.append("-verbose")
        if self.link_time_optimization:
            options.append("-lto")
        if self.ptx:
            options.append("-ptx")
        if self.optimization_level is not None:
            options.append(f"-O{self.optimization_level}")
        if self.debug:
            options.append("-g")
        if self.lineinfo:
            options.append("-lineinfo")
        # Valued boolean flags: both True and False are forwarded explicitly.
        if self.ftz is not None:
            options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            options.append(f"-fma={'true' if self.fma else 'false'}")
        # A single name or a list/tuple of names is accepted, matching the
        # field annotations; each name gets its own flag.
        if self.kernels_used is not None:
            if isinstance(self.kernels_used, str):
                options.append(f"-kernels-used={self.kernels_used}")
            elif isinstance(self.kernels_used, (list, tuple)):
                for kernel in self.kernels_used:
                    options.append(f"-kernels-used={kernel}")
        if self.variables_used is not None:
            if isinstance(self.variables_used, str):
                options.append(f"-variables-used={self.variables_used}")
            elif isinstance(self.variables_used, (list, tuple)):
                for variable in self.variables_used:
                    options.append(f"-variables-used={variable}")
        if self.optimize_unused_variables:
            options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            if isinstance(self.ptxas_options, str):
                options.append(f"-Xptxas={self.ptxas_options}")
            elif is_sequence(self.ptxas_options):
                for opt in self.ptxas_options:
                    options.append(f"-Xptxas={opt}")
        if self.split_compile is not None:
            options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            options.append("-no-cache")

        if as_bytes:
            return [o.encode() for o in options]
        else:
            return options

    def _prepare_driver_options(self) -> tuple[list, list]:
        """Translate the options into parallel value/key lists for cuLinkCreate.

        Returns
        -------
        tuple[list, list]
            ``(formatted_options, option_keys)``. The two lists have the same
            length; entry ``i`` of the first is the value for the
            ``CUjit_option`` key at entry ``i`` of the second. The first four
            entries are always the info log buffer, its size, the error log
            buffer and its size, in that order.

        Raises
        ------
        ValueError
            If an option that the driver API cannot express is requested.
        """
        formatted_options = []
        option_keys = []

        # allocate a fixed-sized buffer for each info/error log (4 MiB each)
        size = 4194304
        formatted_options.extend((bytearray(size), size, bytearray(size), size))
        option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # "sm_90" / "compute_90" -> "90" -> CU_TARGET_COMPUTE_90
            arch = self.arch.split("_")[-1].upper()
            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            formatted_options.append(self.max_register_count)
            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        # Only an actual request for timing output is an error; an explicit
        # ``False`` asks for nothing the driver cannot provide.
        if self.time:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            formatted_options.append(self.optimization_level)
            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            formatted_options.append(1)
            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        # The options below only emit a warning; nothing is added to the
        # key/value lists for them.
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)

        return formatted_options, option_keys

    def as_bytes(self, backend: str = "nvjitlink") -> list[bytes]:
        """Convert linker options to bytes format for the nvjitlink backend.

        Parameters
        ----------
        backend : str, optional
            The linker backend. Only "nvjitlink" is supported. Default is "nvjitlink".

        Returns
        -------
        list[bytes]
            List of option strings encoded as bytes.

        Raises
        ------
        ValueError
            If an unsupported backend is specified.
        RuntimeError
            If nvJitLink backend is not available.
        """
        # The backend name is matched case-insensitively.
        backend = backend.lower()
        if backend != "nvjitlink":
            raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
        if not _use_nvjitlink_backend:
            raise RuntimeError("nvJitLink backend is not available")
        return self._prepare_nvjitlink_options(as_bytes=True)
439# =============================================================================
440# Private implementation: cdef inline helpers
441# =============================================================================
cdef inline int Linker_init(Linker self, tuple object_codes, object options) except -1:
    """Initialize a Linker instance.

    Resolves the options, creates the backend linker handle (nvJitLink when
    ``_use_nvjitlink_backend`` is set, otherwise the driver's cuLink) and
    adds every entry of ``object_codes`` to it.

    Returns 0 on success; -1 signals a raised Python exception. Raises
    ValueError when ``object_codes`` is empty.
    """
    if len(object_codes) == 0:
        raise ValueError("At least one ObjectCode object must be provided")

    cdef cynvjitlink.nvJitLinkHandle c_raw_nvjitlink
    cdef cydriver.CUlinkState c_raw_culink
    cdef Py_ssize_t c_num_opts, i
    cdef vector[const_char_ptr] c_str_opts
    cdef vector[cydriver.CUjit_option] c_jit_keys
    cdef vector[void_ptr] c_jit_values

    # The resolved options object is stored on the instance and also rebinds
    # the local name used below.
    self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")

    if _use_nvjitlink_backend:
        self._use_nvjitlink = True
        options_bytes = options._prepare_nvjitlink_options(as_bytes=True)
        c_num_opts = len(options_bytes)
        c_str_opts.resize(c_num_opts)
        for i in range(c_num_opts):
            # Borrowed pointer into the bytes object; options_bytes is a
            # local, so it stays referenced until this function returns.
            c_str_opts[i] = <const char*>(<bytes>options_bytes[i])
        with nogil:
            # No handle exists yet, hence NULL as the first argument.
            HANDLE_RETURN_NVJITLINK(NULL, cynvjitlink.nvJitLinkCreate(
                &c_raw_nvjitlink, <uint32_t>c_num_opts, c_str_opts.data()))
        self._nvjitlink_handle = create_nvjitlink_handle(c_raw_nvjitlink)
    else:
        self._use_nvjitlink = False
        formatted_options, option_keys = options._prepare_driver_options()
        # Keep the formatted_options list alive: it contains bytearrays that
        # the driver writes into via raw pointers during linking operations.
        self._drv_log_bufs = formatted_options
        c_num_opts = len(option_keys)
        c_jit_keys.resize(c_num_opts)
        c_jit_values.resize(c_num_opts)
        for i in range(c_num_opts):
            c_jit_keys[i] = <cydriver.CUjit_option><int>option_keys[i]
            val = formatted_options[i]
            if isinstance(val, bytearray):
                # Log buffers are passed by address.
                c_jit_values[i] = <void*>PyByteArray_AS_STRING(val)
            else:
                # Every other value is converted with int() and carried in
                # the pointer-sized slot itself.
                c_jit_values[i] = <void*><intptr_t>int(val)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkCreate(
                    <unsigned int>c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink))
        except CUDAError as e:
            # Append the driver's error log to the exception before re-raising.
            Linker_annotate_error_log(self, e)
            raise
        self._culink_handle = create_culink_handle(c_raw_culink)

    # Inputs are type-checked and added one at a time, in the order given.
    for code in object_codes:
        assert_type(code, ObjectCode)
        Linker_add_code_object(self, code)
    return 0
cdef inline void Linker_add_code_object(Linker self, object object_code) except *:
    """Add a single ObjectCode to the linker.

    ``object_code.code`` may be ``bytes`` (handed to the backend's AddData
    call together with ``object_code.name``) or ``str`` (handed to the
    backend's AddFile call).

    Raises ValueError when ``object_code.code_type`` has no entry in the
    active backend's input-type mapping, and TypeError when the code is
    neither ``bytes`` nor ``str``.
    """
    data = object_code.code
    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef cynvjitlink.nvJitLinkInputType c_nv_input_type
    cdef cydriver.CUjitInputType c_drv_input_type
    cdef const char* c_data_ptr
    cdef size_t c_data_size
    cdef const char* c_name_ptr
    cdef const char* c_file_ptr

    # name_bytes must stay referenced while c_name_ptr is in use: the pointer
    # borrows the bytes object's internal buffer.
    name_bytes = f"{object_code.name}".encode()
    c_name_ptr = <const char*>name_bytes

    # Map the ObjectCode's code_type string to the backend's input-type enum value.
    input_types = _nvjitlink_input_types if self._use_nvjitlink else _driver_input_types
    py_input_type = input_types.get(object_code.code_type)
    if py_input_type is None:
        raise ValueError(f"Unknown code_type associated with ObjectCode: {object_code.code_type}")

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        c_nv_input_type = <cynvjitlink.nvJitLinkInputType><int>py_input_type
        if isinstance(data, bytes):
            # Borrowed pointer into `data`; the local keeps it alive across the call.
            c_data_ptr = <const char*>(<bytes>data)
            c_data_size = len(data)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddData(
                    c_nvjitlink_h, c_nv_input_type, <const void*>c_data_ptr, c_data_size, c_name_ptr))
        elif isinstance(data, str):
            # file_bytes keeps the encoded string alive while c_file_ptr is used.
            file_bytes = data.encode()
            c_file_ptr = <const char*>file_bytes
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkAddFile(
                    c_nvjitlink_h, c_nv_input_type, c_file_ptr))
        else:
            raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
    else:
        c_culink_state = as_cu(self._culink_handle)
        c_drv_input_type = <cydriver.CUjitInputType><int>py_input_type
        try:
            if isinstance(data, bytes):
                c_data_ptr = <const char*>(<bytes>data)
                c_data_size = len(data)
                with nogil:
                    # No per-input JIT options are passed (0, NULL, NULL).
                    HANDLE_RETURN(cydriver.cuLinkAddData(
                        c_culink_state, c_drv_input_type, <void*>c_data_ptr, c_data_size, c_name_ptr,
                        0, NULL, NULL))
            elif isinstance(data, str):
                file_bytes = data.encode()
                c_file_ptr = <const char*>file_bytes
                with nogil:
                    HANDLE_RETURN(cydriver.cuLinkAddFile(
                        c_culink_state, c_drv_input_type, c_file_ptr, 0, NULL, NULL))
            else:
                raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")
        except CUDAError as e:
            # Append the driver's error log to the exception before re-raising.
            Linker_annotate_error_log(self, e)
            raise
cdef inline object Linker_link(Linker self, str target_type):
    """Complete linking and return the result as ObjectCode.

    ``target_type`` must be "cubin" or "ptx"; anything else raises
    ValueError. After a successful link the info and error logs are decoded
    and cached on the instance, and the driver log buffers are dropped.
    """
    if target_type not in ("cubin", "ptx"):
        raise ValueError(f"Unsupported target type: {target_type}")

    cdef cynvjitlink.nvJitLinkHandle c_nvjitlink_h
    cdef cydriver.CUlinkState c_culink_state
    cdef size_t c_output_size = 0
    cdef char* c_code_ptr
    cdef void* c_cubin_out = NULL

    if self._use_nvjitlink:
        c_nvjitlink_h = as_cu(self._nvjitlink_handle)
        with nogil:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h, cynvjitlink.nvJitLinkComplete(c_nvjitlink_h))
        # Two-step retrieval: query the output size, then copy the image
        # into a bytearray of that size.
        if target_type == "cubin":
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedCubinSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedCubin(c_nvjitlink_h, c_code_ptr))
        else:
            HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                cynvjitlink.nvJitLinkGetLinkedPtxSize(c_nvjitlink_h, &c_output_size))
            code = bytearray(c_output_size)
            c_code_ptr = <char*>(<bytearray>code)
            with nogil:
                HANDLE_RETURN_NVJITLINK(c_nvjitlink_h,
                    cynvjitlink.nvJitLinkGetLinkedPtx(c_nvjitlink_h, c_code_ptr))
    else:
        # NOTE(review): this branch never consults target_type; whatever
        # cuLinkComplete returns is labelled with the requested type below.
        # Confirm that "ptx" is meaningful for the driver backend.
        c_culink_state = as_cu(self._culink_handle)
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuLinkComplete(c_culink_state, &c_cubin_out, &c_output_size))
        except CUDAError as e:
            # Append the driver's error log to the exception before re-raising.
            Linker_annotate_error_log(self, e)
            raise
        # Copy the output image into a Python bytes object.
        code = (<char*>c_cubin_out)[:c_output_size]

    # Linking is complete; cache the decoded log strings and release
    # the driver's raw bytearray buffers (no longer written to).
    # The getters must run before _drv_log_bufs is cleared: on the driver
    # backend they decode straight from those buffers.
    self._info_log = self.get_info_log()
    self._error_log = self.get_error_log()
    self._drv_log_bufs = None

    # bytes(code) normalizes the bytearray (nvJitLink) / bytes (driver) result.
    return ObjectCode._init(bytes(code), target_type, name=self._options.name)
cdef inline void Linker_annotate_error_log(Linker self, object e):
    """Append the linker's error log to the first argument of exception ``e``.

    ``e`` is modified in place; nothing happens when the log is empty.
    """
    log_text = self.get_error_log()
    if not log_text:
        return
    annotated = e.args[0] + f"\nLinker error log: {log_text}"
    remaining = e.args[1:]
    e.args = (annotated, *remaining)
617# =============================================================================
618# Private implementation: module-level state and initialization
619# =============================================================================
# Module-level backend state, filled in lazily by _lazy_init() and
# _decide_nvjitlink_or_driver().
# TODO: revisit this treatment for py313t builds
_driver = None  # populated if nvJitLink cannot be used
_driver_ver = None  # (major, minor) tuple once the backend decision has been made
_inited = False  # True once _lazy_init() has completed
_use_nvjitlink_backend = False  # set by _decide_nvjitlink_or_driver()

# Input type mappings populated by _lazy_init() with C-level enum ints.
# Only the mapping for the selected backend is populated; the other stays None.
_nvjitlink_input_types = None
_driver_input_types = None
def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
    """Return whether ``nvjitlink`` resolves the ``__nvJitLinkVersion`` symbol."""
    # This condition is equivalent to testing for version >= 12.3
    symbol_ptr = nvjitlink._inspect_function_pointer("__nvJitLinkVersion")
    return True if symbol_ptr else False
637# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
    """Return True if falling back to the cuLink* driver APIs."""
    global _driver_ver, _driver, _use_nvjitlink_backend

    # The decision is made once; later calls replay the recorded outcome.
    if _driver_ver is not None:
        return not _use_nvjitlink_backend

    # The driver reports 1000 * major + 10 * minor; store it as (major, minor).
    encoded_ver = handle_return(driver.cuDriverGetVersion())
    _driver_ver = (encoded_ver // 1000, (encoded_ver % 1000) // 10)

    fallback_note = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    def _probe(module):
        # probe triggers nvJitLink runtime load
        return module.version()

    nvjitlink_module = optional_cuda_import("cuda.bindings.nvjitlink", probe_function=_probe)

    if nvjitlink_module is not None:
        from cuda.bindings._internal import nvjitlink

        if _nvjitlink_has_version_symbol(nvjitlink):
            _use_nvjitlink_backend = True
            return False  # Use nvjitlink
        library = "nvJitLink*.dll" if sys.platform == "win32" else "libnvJitLink.so*"
        message = (
            f"{library} is too old (<12.3)."
            f" Therefore cuda.bindings.nvjitlink is not usable and {fallback_note} nvJitLink."
        )
    else:
        message = f"cuda.bindings.nvjitlink is not available, therefore {fallback_note} cuda-bindings."

    # Falling back: warn the caller and remember the driver bindings module.
    warn(message, stacklevel=2, category=RuntimeWarning)
    _driver = driver
    return True
def _lazy_init():
    """Select the backend once and build its code-type to input-type mapping."""
    global _inited, _nvjitlink_input_types, _driver_input_types

    # Everything below runs at most once per process.
    if _inited:
        return

    _decide_nvjitlink_or_driver()

    if not _use_nvjitlink_backend:
        # Driver (cuLink) backend; it has no "ltoir" entry.
        _driver_input_types = {
            "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
            "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
            "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
            "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
            "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
        }
    else:
        # nvJitLink backend.
        _nvjitlink_input_types = {
            "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
            "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
            "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
            "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
            "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
            "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
        }

    _inited = True