Coverage for cuda / core / experimental / _linker.py: 70%
292 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7import ctypes
8import sys
9import weakref
10from contextlib import contextmanager
11from dataclasses import dataclass
12from typing import TYPE_CHECKING, Union
13from warnings import warn
15if TYPE_CHECKING:
16 import cuda.bindings
18from cuda.core.experimental._device import Device
19from cuda.core.experimental._module import ObjectCode
20from cuda.core.experimental._utils.clear_error_support import assert_type
21from cuda.core.experimental._utils.cuda_utils import check_or_create_options, driver, handle_return, is_sequence
# Module-level backend state, filled in lazily by _decide_nvjitlink_or_driver() and _lazy_init().
# TODO: revisit this treatment for py313t builds
_driver = None  # populated if nvJitLink cannot be used
_driver_input_types = None  # populated if nvJitLink cannot be used
_driver_ver = None  # (major, minor) tuple derived from cuDriverGetVersion()
_inited = False  # flipped to True once _lazy_init() has completed
_nvjitlink = None  # populated if nvJitLink can be used
_nvjitlink_input_types = None  # populated if nvJitLink can be used
32def _nvjitlink_has_version_symbol(inner_nvjitlink) -> bool:
33 # This condition is equivalent to testing for version >= 12.3
34 return bool(inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
# Note: this function is reused in the tests
def _decide_nvjitlink_or_driver() -> bool:
    """Returns True if falling back to the cuLink* driver APIs.

    The outcome is stored in the module globals ``_nvjitlink`` (nvJitLink is usable) or
    ``_driver`` (driver fallback), so later calls return the earlier decision without
    probing again. ``_driver_ver`` is set to a ``(major, minor)`` tuple as a side effect.
    A ``RuntimeWarning`` is emitted whenever the driver fallback is selected.
    """
    global _driver_ver, _driver, _nvjitlink
    # A backend was already chosen by a previous call; report that choice.
    if _driver or _nvjitlink:
        return _driver is not None

    _driver_ver = handle_return(driver.cuDriverGetVersion())
    # Split the integer driver version into a (major, minor) tuple.
    _driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)

    # Shared middle part of the warning; each branch below appends what to upgrade.
    warn_txt_common = (
        "the driver APIs will be used instead, which do not support"
        " minor version compatibility or linking LTO IRs."
        " For best results, consider upgrading to a recent version of"
    )

    try:
        # On success this binds the module-level ``_nvjitlink`` (declared global above).
        import cuda.bindings.nvjitlink as _nvjitlink
    except ModuleNotFoundError:
        warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
    else:
        from cuda.bindings._internal import nvjitlink as inner_nvjitlink

        try:
            if _nvjitlink_has_version_symbol(inner_nvjitlink):
                return False  # Use nvjitlink
        except RuntimeError:
            # The symbol lookup itself failed.
            warn_detail = "not available"
        else:
            # The lookup worked but the version symbol is missing.
            warn_detail = "too old (<12.3)"
        warn_txt = (
            f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is {warn_detail}."
            f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
        )
        # Undo the global binding made by the import so the driver path is taken everywhere.
        _nvjitlink = None

    warn(warn_txt, stacklevel=2, category=RuntimeWarning)
    _driver = driver
    return True
def _lazy_init():
    """Select the linking backend once and build its code-type lookup table.

    On the first call this runs ``_decide_nvjitlink_or_driver()`` and then fills either
    ``_nvjitlink_input_types`` or ``_driver_input_types`` with a mapping from code-type
    strings to the backend's input-type enum members. Later calls do nothing.
    """
    global _inited, _nvjitlink_input_types, _driver_input_types
    if not _inited:
        _decide_nvjitlink_or_driver()
        if not _nvjitlink:
            jit_input = _driver.CUjitInputType
            # Note: the driver table has no "ltoir" entry.
            _driver_input_types = dict(
                ptx=jit_input.CU_JIT_INPUT_PTX,
                cubin=jit_input.CU_JIT_INPUT_CUBIN,
                fatbin=jit_input.CU_JIT_INPUT_FATBINARY,
                object=jit_input.CU_JIT_INPUT_OBJECT,
                library=jit_input.CU_JIT_INPUT_LIBRARY,
            )
        else:
            if _driver_ver > _nvjitlink.version():
                # TODO: nvJitLink is not new enough, warn?
                pass
            input_kind = _nvjitlink.InputType
            _nvjitlink_input_types = dict(
                ptx=input_kind.PTX,
                cubin=input_kind.CUBIN,
                fatbin=input_kind.FATBIN,
                ltoir=input_kind.LTOIR,
                object=input_kind.OBJECT,
                library=input_kind.LIBRARY,
            )
        # Only mark initialisation done after the table was built without error.
        _inited = True
@dataclass
class LinkerOptions:
    """Customizable :obj:`Linker` options.

    Since the linker would choose to use nvJitLink or the driver APIs as the linking backend,
    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
    or not installed, the driver APIs (cuLink) will be used instead.

    After construction, ``formatted_options`` holds the option values in the form expected by
    the selected backend. For the driver backend, ``option_keys`` additionally holds the
    ``CUjit_option`` key matching each entry of ``formatted_options``.

    Attributes
    ----------
    name : str, optional
        Name of the linker. If the linking succeeds, the name is passed down to the generated `ObjectCode`.
    arch : str, optional
        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
        will be used.
    max_register_count : int, optional
        Maximum register count.
    time : bool, optional
        Print timing information to the info log.
        Default: False.
    verbose : bool, optional
        Print verbose messages to the info log.
        Default: False.
    link_time_optimization : bool, optional
        Perform link time optimization.
        Default: False.
    ptx : bool, optional
        Emit PTX after linking instead of CUBIN; only supported with ``link_time_optimization=True``.
        Default: False.
    optimization_level : int, optional
        Set optimization level. Only 0 and 3 are accepted.
    debug : bool, optional
        Generate debug information.
        Default: False.
    lineinfo : bool, optional
        Generate line information.
        Default: False.
    ftz : bool, optional
        Flush denormal values to zero.
        Default: False.
    prec_div : bool, optional
        Use precise division.
        Default: True.
    prec_sqrt : bool, optional
        Use precise square root.
        Default: True.
    fma : bool, optional
        Use fast multiply-add.
        Default: True.
    kernels_used : [Union[str, tuple[str], list[str]]], optional
        Pass a kernel or sequence of kernels that are used; any not in the list can be removed.
    variables_used : [Union[str, tuple[str], list[str]]], optional
        Pass a variable or sequence of variables that are used; any not in the list can be removed.
    optimize_unused_variables : bool, optional
        Assume that if a variable is not referenced in device code, it can be removed.
        Default: False.
    ptxas_options : [Union[str, tuple[str], list[str]]], optional
        Pass options to PTXAS.
    split_compile : int, optional
        Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
        compilation (default).
        Default: 1.
    split_compile_extended : int, optional
        A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value.
        Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This
        option can potentially impact performance of the compiled binary.
        Default: 1.
    no_cache : bool, optional
        Do not cache the intermediate steps of nvJitLink.
        Default: False.
    """

    name: str | None = "<default linker>"
    arch: str | None = None
    max_register_count: int | None = None
    time: bool | None = None
    verbose: bool | None = None
    link_time_optimization: bool | None = None
    ptx: bool | None = None
    optimization_level: int | None = None
    debug: bool | None = None
    lineinfo: bool | None = None
    ftz: bool | None = None
    prec_div: bool | None = None
    prec_sqrt: bool | None = None
    fma: bool | None = None
    kernels_used: Union[str, tuple[str], list[str]] | None = None
    variables_used: Union[str, tuple[str], list[str]] | None = None
    optimize_unused_variables: bool | None = None
    ptxas_options: Union[str, tuple[str], list[str]] | None = None
    split_compile: int | None = None
    split_compile_extended: int | None = None
    no_cache: bool | None = None

    def __post_init__(self):
        """Initialise the backend lazily and translate the fields into backend options."""
        _lazy_init()
        self._name = self.name.encode()
        self.formatted_options = []
        if _nvjitlink:
            self._init_nvjitlink()
        else:
            self._init_driver()

    @staticmethod
    def _expand_repeated_option(flag: str, values) -> list[str]:
        """Return one ``flag=value`` string per entry of ``values``.

        ``values`` may be a single string or a list/tuple of strings, matching the declared
        field types. Any other type yields no options.
        """
        if isinstance(values, str):
            return [f"{flag}={values}"]
        if isinstance(values, (list, tuple)):
            return [f"{flag}={value}" for value in values]
        return []

    def _init_nvjitlink(self):
        """Fill ``formatted_options`` with nvJitLink command-line style option strings."""
        if self.arch is not None:
            self.formatted_options.append(f"-arch={self.arch}")
        else:
            self.formatted_options.append("-arch=sm_" + "".join(f"{i}" for i in Device().compute_capability))
        if self.max_register_count is not None:
            self.formatted_options.append(f"-maxrregcount={self.max_register_count}")
        # Boolean switch: an explicit ``time=False`` must not enable timing output.
        if self.time:
            self.formatted_options.append("-time")
        if self.verbose:
            self.formatted_options.append("-verbose")
        if self.link_time_optimization:
            self.formatted_options.append("-lto")
        if self.ptx:
            self.formatted_options.append("-ptx")
        if self.optimization_level is not None:
            self.formatted_options.append(f"-O{self.optimization_level}")
        if self.debug:
            self.formatted_options.append("-g")
        if self.lineinfo:
            self.formatted_options.append("-lineinfo")
        if self.ftz is not None:
            self.formatted_options.append(f"-ftz={'true' if self.ftz else 'false'}")
        if self.prec_div is not None:
            self.formatted_options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
        if self.prec_sqrt is not None:
            self.formatted_options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
        if self.fma is not None:
            self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
        # Tuples are accepted here as well as lists, as the field annotations promise.
        if self.kernels_used is not None:
            self.formatted_options.extend(self._expand_repeated_option("-kernels-used", self.kernels_used))
        if self.variables_used is not None:
            self.formatted_options.extend(self._expand_repeated_option("-variables-used", self.variables_used))
        # Boolean switch: an explicit ``False`` must not enable the optimisation.
        if self.optimize_unused_variables:
            self.formatted_options.append("-optimize-unused-variables")
        if self.ptxas_options is not None:
            if isinstance(self.ptxas_options, str):
                self.formatted_options.append(f"-Xptxas={self.ptxas_options}")
            elif is_sequence(self.ptxas_options):
                for opt in self.ptxas_options:
                    self.formatted_options.append(f"-Xptxas={opt}")
        if self.split_compile is not None:
            self.formatted_options.append(f"-split-compile={self.split_compile}")
        if self.split_compile_extended is not None:
            self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
        if self.no_cache is True:
            self.formatted_options.append("-no-cache")

    def _init_driver(self):
        """Fill ``formatted_options`` / ``option_keys`` for the cuLink* driver APIs.

        The first four entries are always the info-log buffer, its size, the error-log buffer
        and its size, in that order.

        Raises
        ------
        ValueError
            If an option that the driver API does not support is set.
        """
        self.option_keys = []
        # allocate 4 MiB each for info/error logs
        size = 4194304
        self.formatted_options.extend((bytearray(size), size, bytearray(size), size))
        self.option_keys.extend(
            (
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            )
        )

        if self.arch is not None:
            # Keep only the part after the last underscore (e.g. "sm_80" -> "80") for the enum name.
            arch = self.arch.split("_")[-1].upper()
            self.formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
            self.option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
        if self.max_register_count is not None:
            self.formatted_options.append(self.max_register_count)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
        if self.time is not None:
            raise ValueError("time option is not supported by the driver API")
        if self.verbose:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
        if self.link_time_optimization:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
        if self.ptx:
            raise ValueError("ptx option is not supported by the driver API")
        if self.optimization_level is not None:
            self.formatted_options.append(self.optimization_level)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
        if self.debug:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
        if self.lineinfo:
            self.formatted_options.append(1)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
        # The options below only warn: nothing is forwarded to the driver for them.
        if self.ftz is not None:
            warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_div is not None:
            warn("prec_div option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.prec_sqrt is not None:
            warn("prec_sqrt option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.fma is not None:
            warn("fma options is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.kernels_used is not None:
            warn("kernels_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.variables_used is not None:
            warn("variables_used is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.optimize_unused_variables is not None:
            warn("optimize_unused_variables is deprecated in the driver API", DeprecationWarning, stacklevel=3)
        if self.ptxas_options is not None:
            raise ValueError("ptxas_options option is not supported by the driver API")
        if self.split_compile is not None:
            raise ValueError("split_compile option is not supported by the driver API")
        if self.split_compile_extended is not None:
            raise ValueError("split_compile_extended option is not supported by the driver API")
        if self.no_cache is True:
            self.formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
            self.option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)
332# This needs to be a free function not a method, as it's disallowed by contextmanager.
333@contextmanager
334def _exception_manager(self):
335 """
336 A helper function to improve the error message of exceptions raised by the linker backend.
337 """
338 try:
339 yield
340 except Exception as e:
341 error_log = ""
342 if hasattr(self, "_mnff"):
343 # our constructor could raise, in which case there's no handle available
344 error_log = self.get_error_log()
345 # Starting Python 3.11 we could also use Exception.add_note() for the same purpose, but
346 # unfortunately we are still supporting Python 3.10...
347 # Here we rely on both CUDAError and nvJitLinkError have the error string placed in .args[0].
348 e.args = (e.args[0] + (f"\nLinker error log: {error_log}" if error_log else ""), *e.args[1:])
349 raise e
# Type aliases for the handle exposed by Linker.handle.
# nvJitLink handles are aliased to a plain int here.
nvJitLinkHandleT = int
# The driver alternative is referenced by string name, so it is not resolved at import time.
LinkerHandleT = Union[nvJitLinkHandleT, "cuda.bindings.driver.CUlinkState"]
class Linker:
    """Represent a linking machinery to link one or multiple object codes into
    :obj:`~cuda.core.experimental._module.ObjectCode` with the specified options.

    This object provides a unified interface to multiple underlying
    linker libraries (such as nvJitLink or cuLink* from CUDA driver).

    Parameters
    ----------
    object_codes : ObjectCode
        One or more ObjectCode objects to be linked.
    options : LinkerOptions, optional
        Options for the linker. If not provided, default options will be used.
    """

    class _MembersNeededForFinalize:
        """Holds the backend handle so it can be destroyed when the owning Linker goes away."""

        __slots__ = ("handle", "use_nvjitlink", "const_char_keep_alive")

        def __init__(self, program_obj, handle, use_nvjitlink):
            self.handle = handle
            self.use_nvjitlink = use_nvjitlink
            # Byte strings handed to the driver are kept referenced here.
            self.const_char_keep_alive = []
            weakref.finalize(program_obj, self.close)

        def close(self):
            """Destroy the backend handle if it has not been destroyed yet."""
            if self.handle is not None:
                if self.use_nvjitlink:
                    _nvjitlink.destroy(self.handle)
                else:
                    handle_return(_driver.cuLinkDestroy(self.handle))
                self.handle = None

    __slots__ = ("__weakref__", "_mnff", "_options")

    def __init__(self, *object_codes: ObjectCode, options: LinkerOptions | None = None):
        if len(object_codes) == 0:
            raise ValueError("At least one ObjectCode object must be provided")

        self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
        with _exception_manager(self):
            if _nvjitlink:
                handle = _nvjitlink.create(len(options.formatted_options), options.formatted_options)
                use_nvjitlink = True
            else:
                handle = handle_return(
                    _driver.cuLinkCreate(len(options.formatted_options), options.option_keys, options.formatted_options)
                )
                use_nvjitlink = False
        self._mnff = Linker._MembersNeededForFinalize(self, handle, use_nvjitlink)

        for code in object_codes:
            assert_type(code, ObjectCode)
            self._add_code_object(code)

    def _add_code_object(self, object_code: ObjectCode):
        """Register one ObjectCode with the backend.

        ``bytes`` payloads are passed to the backend's add-data call, ``str`` payloads to its
        add-file call.

        Raises
        ------
        TypeError
            If the payload is neither ``bytes`` nor ``str``.
        """
        data = object_code._module
        with _exception_manager(self):
            name_str = f"{object_code.name}"
            if _nvjitlink and isinstance(data, bytes):
                _nvjitlink.add_data(
                    self._mnff.handle,
                    self._input_type_from_code_type(object_code._code_type),
                    data,
                    len(data),
                    name_str,
                )
            elif _nvjitlink and isinstance(data, str):
                _nvjitlink.add_file(
                    self._mnff.handle,
                    self._input_type_from_code_type(object_code._code_type),
                    data,
                )
            elif (not _nvjitlink) and isinstance(data, bytes):
                name_bytes = name_str.encode()
                handle_return(
                    _driver.cuLinkAddData(
                        self._mnff.handle,
                        self._input_type_from_code_type(object_code._code_type),
                        data,
                        len(data),
                        name_bytes,
                        0,
                        None,
                        None,
                    )
                )
                self._mnff.const_char_keep_alive.append(name_bytes)
            elif (not _nvjitlink) and isinstance(data, str):
                # The name is not passed to cuLinkAddFile, but is still recorded below.
                name_bytes = name_str.encode()
                handle_return(
                    _driver.cuLinkAddFile(
                        self._mnff.handle,
                        self._input_type_from_code_type(object_code._code_type),
                        data.encode(),
                        0,
                        None,
                        None,
                    )
                )
                self._mnff.const_char_keep_alive.append(name_bytes)
            else:
                raise TypeError(f"Expected bytes or str, but got {type(data).__name__}")

    def link(self, target_type: str) -> ObjectCode:
        """
        Links the provided object codes into a single output of the specified target type.

        Parameters
        ----------
        target_type : str
            The type of the target output. Must be either "cubin" or "ptx".

        Returns
        -------
        ObjectCode
            The linked object code of the specified target type.

        Raises
        ------
        ValueError
            If ``target_type`` is neither "cubin" nor "ptx".

        Note
        ----
        See nvrtc compiler options documentation to ensure the input object codes are
        correctly compiled for linking.
        """
        if target_type not in ("cubin", "ptx"):
            raise ValueError(f"Unsupported target type: {target_type}")
        with _exception_manager(self):
            if _nvjitlink:
                _nvjitlink.complete(self._mnff.handle)
                if target_type == "cubin":
                    get_size = _nvjitlink.get_linked_cubin_size
                    get_code = _nvjitlink.get_linked_cubin
                else:
                    get_size = _nvjitlink.get_linked_ptx_size
                    get_code = _nvjitlink.get_linked_ptx
                size = get_size(self._mnff.handle)
                code = bytearray(size)
                get_code(self._mnff.handle, code)
            else:
                addr, size = handle_return(_driver.cuLinkComplete(self._mnff.handle))
                # View the driver's output in place; bytes(code) below makes the copy.
                code = (ctypes.c_char * size).from_address(addr)

        return ObjectCode._init(bytes(code), target_type, name=self._options.name)

    def get_error_log(self) -> str:
        """Get the error log generated by the linker.

        Returns
        -------
        str
            The error log. With the driver backend, the unused zero padding of the
            pre-allocated log buffer is stripped.
        """
        if self._mnff.use_nvjitlink:
            log_size = _nvjitlink.get_error_log_size(self._mnff.handle)
            log = bytearray(log_size)
            _nvjitlink.get_error_log(self._mnff.handle, log)
        else:
            # The buffer is pre-allocated zero-filled, so trailing NULs are padding, not log text.
            log = self._options.formatted_options[2].rstrip(b"\x00")
        return log.decode("utf-8", errors="backslashreplace")

    def get_info_log(self) -> str:
        """Get the info log generated by the linker.

        Returns
        -------
        str
            The info log. With the driver backend, the unused zero padding of the
            pre-allocated log buffer is stripped.
        """
        if self._mnff.use_nvjitlink:
            log_size = _nvjitlink.get_info_log_size(self._mnff.handle)
            log = bytearray(log_size)
            _nvjitlink.get_info_log(self._mnff.handle, log)
        else:
            # The buffer is pre-allocated zero-filled, so trailing NULs are padding, not log text.
            log = self._options.formatted_options[0].rstrip(b"\x00")
        return log.decode("utf-8", errors="backslashreplace")

    def _input_type_from_code_type(self, code_type: str):
        """Map an ObjectCode code-type string to the active backend's input-type enum.

        Raises
        ------
        ValueError
            If ``code_type`` has no entry for the active backend.
        """
        # this list is based on the supported values for code_type in the ObjectCode class definition.
        # nvJitLink/driver support other options for input type
        input_type = _nvjitlink_input_types.get(code_type) if _nvjitlink else _driver_input_types.get(code_type)

        if input_type is None:
            raise ValueError(f"Unknown code_type associated with ObjectCode: {code_type}")
        return input_type

    @property
    def handle(self) -> LinkerHandleT:
        """Return the underlying handle object.

        .. note::

            The type of the returned object depends on the backend.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Linker.handle)``.
        """
        return self._mnff.handle

    @property
    def backend(self) -> str:
        """Return this Linker instance's underlying backend."""
        return "nvJitLink" if self._mnff.use_nvjitlink else "driver"

    def close(self):
        """Destroy this linker."""
        self._mnff.close()