Coverage for cuda / core / experimental / _utils / cuda_utils.pyx: 79%
180 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5import functools
6from functools import partial
7import importlib.metadata
8import multiprocessing
9import platform
10import warnings
11from collections import namedtuple
12from collections.abc import Sequence
13from contextlib import ExitStack
14from typing import Callable
16try:
17 from cuda.bindings import driver, nvrtc, runtime
18except ImportError:
19 from cuda import cuda as driver
20 from cuda import cudart as runtime
21 from cuda import nvrtc
23from cuda.core.experimental._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS
24from cuda.core.experimental._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS
class CUDAError(Exception):
    """Raised when a CUDA driver or runtime API call returns an error status."""
    pass
class NVRTCError(CUDAError):
    """Raised when an NVRTC API call returns an error status (subclass of CUDAError)."""
    pass
35ComputeCapability = namedtuple("ComputeCapability", ("major", "minor"))
def cast_to_3_tuple(label, cfg):
    """Coerce *cfg* — an int or a tuple of up to 3 positive ints — into a 3-tuple.

    Missing trailing dimensions are padded with 1. Raises ValueError (with a
    message prefixed by *label*) for non-int/non-tuple input, tuples longer
    than 3, non-int elements, or values below 1.
    """
    original = cfg
    if not isinstance(cfg, int):
        problem = "must be an int, or a tuple with up to 3 ints"
        if not isinstance(cfg, tuple):
            raise ValueError(f"{label} {problem} (got {type(cfg)})")
        if len(cfg) > 3:
            raise ValueError(f"{label} {problem} (got tuple with length {len(cfg)})")
        for entry in cfg:
            if not isinstance(entry, int):
                raise ValueError(f"{label} {problem} (got {cfg})")
    else:
        cfg = (cfg,)
    if any(entry < 1 for entry in cfg):
        suffix = "s" if len(cfg) != 1 else ""
        raise ValueError(f"{label} value{suffix} must be >= 1 (got {original})")
    return cfg + (1,) * (3 - len(cfg))
56def _reduce_3_tuple(t: tuple):
57 return t[0] * t[1] * t[2]
# Check the status returned by a bindings call and raise on failure.
# `supported_error_type` is presumably a Cython fused type declared elsewhere
# in the project (TODO confirm); only the CUresult specialization is handled
# here, delegating to _check_driver_error for non-success codes.
cdef int HANDLE_RETURN(supported_error_type err) except?-1 nogil:
    if supported_error_type is cydriver.CUresult:
        if err != cydriver.CUresult.CUDA_SUCCESS:
            return _check_driver_error(err)
# Cache the success sentinels once at module load so the hot error-check
# paths avoid repeated attribute lookups on the bindings modules.
cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess
cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS
# Validate a driver CUresult: return 0 on success, otherwise raise CUDAError
# with the best available diagnostic — a curated explanation when one exists,
# else the driver-provided description, else just the error name.
cpdef inline int _check_driver_error(cydriver.CUresult error) except?-1 nogil:
    if error == cydriver.CUresult.CUDA_SUCCESS:
        return 0
    cdef const char* name
    name_err = cydriver.cuGetErrorName(error, &name)
    if name_err != cydriver.CUresult.CUDA_SUCCESS:
        # The driver does not even recognize this code; report it raw.
        raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
    with gil:
        # TODO: consider lower this to Cython
        # Dict lookup requires Python objects, hence the explicit GIL scope.
        expl = DRIVER_CU_RESULT_EXPLANATIONS.get(int(error))
        if expl is not None:
            raise CUDAError(f"{name.decode()}: {expl}")
    # No curated explanation; fall back to the driver's own description string.
    cdef const char* desc
    desc_err = cydriver.cuGetErrorString(error, &desc)
    if desc_err != cydriver.CUresult.CUDA_SUCCESS:
        raise CUDAError(f"{name.decode()}")
    raise CUDAError(f"{name.decode()}: {desc.decode()}")
cpdef inline int _check_runtime_error(error) except?-1:
    """Validate a CUDART status: return 0 on success, else raise CUDAError.

    The message prefers a curated explanation from
    RUNTIME_CUDA_ERROR_EXPLANATIONS, falling back to the runtime's own
    description string, then to the bare error name.
    """
    if error == _RUNTIME_SUCCESS:
        return 0
    name_status, name = runtime.cudaGetErrorName(error)
    if name_status != _RUNTIME_SUCCESS:
        # The runtime does not recognize this code at all.
        raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
    name = name.decode()
    explanation = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error))
    if explanation is not None:
        raise CUDAError(f"{name}: {explanation}")
    desc_status, desc = runtime.cudaGetErrorString(error)
    if desc_status != _RUNTIME_SUCCESS:
        raise CUDAError(f"{name}")
    desc = desc.decode()
    raise CUDAError(f"{name}: {desc}")
cpdef inline int _check_nvrtc_error(error, handle=None) except?-1:
    """Validate an NVRTC status: return 0 on success, else raise NVRTCError.

    When *handle* (an NVRTC program handle) is given, the program's
    compilation log is fetched and appended to the error message.
    """
    if error == _NVRTC_SUCCESS:
        return 0
    message = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}"
    if handle is not None:
        _, logsize = nvrtc.nvrtcGetProgramLogSize(handle)
        # Pre-sized buffer the bindings fill in place.
        log = b" " * logsize
        _ = nvrtc.nvrtcGetProgramLog(handle, log)
        message += f", compilation log:\n\n{log.decode('utf-8', errors='backslashreplace')}"
    raise NVRTCError(message)
cdef inline int _check_error(error, handle=None) except?-1:
    """Dispatch *error* to the driver/runtime/NVRTC checker matching its type.

    *handle* is forwarded only to the NVRTC checker (for log retrieval).
    Raises RuntimeError for unrecognized error types.
    """
    if isinstance(error, driver.CUresult):
        return _check_driver_error(error)
    if isinstance(error, runtime.cudaError_t):
        return _check_runtime_error(error)
    if isinstance(error, nvrtc.nvrtcResult):
        return _check_nvrtc_error(error, handle=handle)
    raise RuntimeError(f"Unknown error type: {error}")
def handle_return(tuple result, handle=None):
    """Check the status element of a bindings call result and unpack the rest.

    *result* is the (status, value...) tuple the cuda bindings return.
    Returns None when there is no payload, the single value when there is
    exactly one, and the payload tuple otherwise. Raises via _check_error
    when the status is not a success code.
    """
    _check_error(result[0], handle=handle)
    cdef int n = len(result)
    if n == 1:
        return None
    if n == 2:
        return result[1]
    return result[1:]
cpdef check_or_create_options(type cls, options, str options_description="", bint keep_none=False):
    """Normalize *options* into an instance of *cls*.

    None yields a default-constructed cls (or None itself when *keep_none*),
    an existing cls instance passes through unchanged, and a dict is expanded
    into the cls constructor. Anything else raises TypeError.
    """
    if options is None:
        return None if keep_none else cls()
    if isinstance(options, cls):
        return options
    if isinstance(options, dict):
        return cls(**options)
    raise TypeError(
        f"The {options_description} must be provided as an object "
        f"of type {cls.__name__} or as a dict with valid {options_description}. "
        f"The provided object is '{options}'."
    )
160def _handle_boolean_option(option: bool) -> str:
161 """
162 Convert a boolean option to a string representation.
163 """
164 return "true" if bool(option) else "false"
def precondition(checker: Callable[..., None], str what="") -> Callable:
    """Decorator factory that runs *checker* before the decorated function.

    Args:
        checker: Validation callable sharing the wrapped function's signature,
            plus the extra keyword argument `what`.
        what: Context string forwarded to *checker*.

    Returns:
        Callable: The decorator that performs the wrapping.
    """

    def outer(wrapped_function):
        # The actual decorator: wrap so the precondition runs first.
        @functools.wraps(wrapped_function)
        def inner(*args, **kwargs):
            # Validate, then delegate with the untouched arguments.
            checker(*args, **kwargs, what=what)
            return wrapped_function(*args, **kwargs)

        return inner

    return outer
# Query the device backing `target_ctx`, temporarily making it current when it
# is not already. The push/pop order is critical: the current context is
# restored before returning so callers observe no context change.
cdef cydriver.CUdevice get_device_from_ctx(
    cydriver.CUcontext target_ctx, cydriver.CUcontext curr_ctx) except?cydriver.CU_DEVICE_INVALID nogil:
    """Get device ID from the given ctx."""
    cdef bint switch_context = (curr_ctx != target_ctx)
    cdef cydriver.CUcontext ctx
    cdef cydriver.CUdevice target_dev
    with nogil:
        if switch_context:
            # Swap in the target context (popped ctx must be the caller's).
            HANDLE_RETURN(cydriver.cuCtxPopCurrent(&ctx))
            assert curr_ctx == ctx
            HANDLE_RETURN(cydriver.cuCtxPushCurrent(target_ctx))
        HANDLE_RETURN(cydriver.cuCtxGetDevice(&target_dev))
        if switch_context:
            # Restore the caller's context before returning.
            HANDLE_RETURN(cydriver.cuCtxPopCurrent(&ctx))
            assert target_ctx == ctx
            HANDLE_RETURN(cydriver.cuCtxPushCurrent(curr_ctx))
    return target_dev
def is_sequence(obj):
    """Return True when *obj* is an abstract Sequence (list, tuple, str, ...)."""
    return isinstance(obj, Sequence)
def is_nested_sequence(obj):
    """Return True when *obj* is a sequence with at least one sequence element."""
    if not isinstance(obj, Sequence):
        return False
    return any(isinstance(item, Sequence) for item in obj)
@functools.lru_cache
def get_binding_version():
    """Return the installed cuda bindings version as an (major, minor) int tuple.

    Prefers the `cuda-bindings` distribution, falling back to the legacy
    `cuda-python` name. Cached for the lifetime of the process.
    """
    try:
        version = importlib.metadata.version("cuda-bindings")
    except importlib.metadata.PackageNotFoundError:
        version = importlib.metadata.version("cuda-python")
    return tuple(int(part) for part in version.split(".")[:2])
class Transaction:
    """Context manager implementing commit/rollback semantics over an ExitStack.

    Undo callbacks registered with :meth:`append` run in LIFO order when the
    with-block exits — unless :meth:`commit` was called first, which disarms
    them all. Useful for rolling back partially-completed resource setup on
    error or early exit.

    Usage:
        with Transaction() as txn:
            txn.append(some_cleanup_function, arg1, arg2)
            # ... perform operations ...
            txn.commit()  # Disarm undo actions; nothing rolls back on exit
    """

    def __init__(self):
        self._stack = ExitStack()
        self._entered = False

    def __enter__(self):
        self._stack.__enter__()
        self._entered = True
        return self

    def __exit__(self, exc_type, exc, tb):
        # Whatever callbacks are still registered fire now, most recent first.
        self._entered = False
        return self._stack.__exit__(exc_type, exc, tb)

    def append(self, fn, /, *args, **kwargs):
        """Register an undo action to run if the with-block exits uncommitted.

        Arguments are bound immediately via partial, so later mutation of the
        passed objects' bindings does not change what the callback receives.
        """
        if not self._entered:
            raise RuntimeError("Transaction must be entered before append()")
        self._stack.callback(partial(fn, *args, **kwargs))

    def commit(self):
        """Disarm every registered undo action for this transaction."""
        # pop_all() transfers callbacks to a throwaway stack, leaving this
        # one empty so nothing runs on exit.
        self._stack.pop_all()
# Track whether we've already warned about fork method.
# One-shot guard: check_multiprocessing_start_method() warns at most once per
# process; reset_fork_warning() clears it for tests.
_fork_warning_checked = False
def reset_fork_warning():
    """Clear the one-shot fork-warning flag (intended for tests).

    Lets repeated test runs re-trigger the warning emitted by
    check_multiprocessing_start_method().
    """
    global _fork_warning_checked
    _fork_warning_checked = False
def check_multiprocessing_start_method():
    """Warn (at most once per process) when multiprocessing would fork.

    Forked subprocesses and CUDA do not mix; this emits a UserWarning when
    the start method is 'fork', or when it is unset on Linux (where 'fork'
    is the default).
    """
    global _fork_warning_checked
    if _fork_warning_checked:
        return
    _fork_warning_checked = True

    # Common warning message parts
    common_message = (
        "CUDA does not support. Forked subprocesses exhibit undefined behavior, "
        "including failure to initialize CUDA contexts and devices. Set the start method "
        "to 'spawn' before creating processes that use CUDA. "
        "Use: multiprocessing.set_start_method('spawn')"
    )

    try:
        method = multiprocessing.get_start_method()
    except RuntimeError:
        # get_start_method() raises when no method has been set yet; the
        # platform default then applies, which is 'fork' on Linux — warn there.
        if platform.system() == "Linux":
            warnings.warn(
                f"multiprocessing start method is not set and defaults to 'fork' on Linux, "
                f"which {common_message}",
                UserWarning,
                stacklevel=3,
            )
    else:
        if method == "fork":
            warnings.warn(
                f"multiprocessing start method is 'fork', which {common_message}",
                UserWarning,
                stacklevel=3,
            )