Coverage for cuda/core/experimental/_utils/cuda_utils.pyx: 79%

180 statements  

coverage.py v7.13.0, created at 2025-12-10 01:19 +0000

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

import functools
from functools import partial
import importlib.metadata
import multiprocessing
import platform
import warnings
from collections import namedtuple
from collections.abc import Sequence
from contextlib import ExitStack
from typing import Callable

try:
    from cuda.bindings import driver, nvrtc, runtime
except ImportError:
    from cuda import cuda as driver
    from cuda import cudart as runtime
    from cuda import nvrtc

from cuda.core.experimental._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS
from cuda.core.experimental._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS


class CUDAError(Exception):
    pass


class NVRTCError(CUDAError):
    pass


ComputeCapability = namedtuple("ComputeCapability", ("major", "minor"))


def cast_to_3_tuple(label, cfg):
    cfg_orig = cfg
    if isinstance(cfg, int):
        cfg = (cfg,)
    else:
        common = "must be an int, or a tuple with up to 3 ints"
        if not isinstance(cfg, tuple):
            raise ValueError(f"{label} {common} (got {type(cfg)})")
        if len(cfg) > 3:
            raise ValueError(f"{label} {common} (got tuple with length {len(cfg)})")
        if any(not isinstance(val, int) for val in cfg):
            raise ValueError(f"{label} {common} (got {cfg})")
    if any(val < 1 for val in cfg):
        plural_s = "" if len(cfg) == 1 else "s"
        raise ValueError(f"{label} value{plural_s} must be >= 1 (got {cfg_orig})")
    return cfg + (1,) * (3 - len(cfg))

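# Example (illustrative): an int is promoted to a 1-tuple and padded with 1s;
# a short tuple is padded likewise; out-of-range values are rejected.
#
#   >>> cast_to_3_tuple("grid", 4)
#   (4, 1, 1)
#   >>> cast_to_3_tuple("block", (2, 3))
#   (2, 3, 1)
#   >>> cast_to_3_tuple("grid", (0,))
#   Traceback (most recent call last):
#       ...
#   ValueError: grid value must be >= 1 (got (0,))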


def _reduce_3_tuple(t: tuple):
    return t[0] * t[1] * t[2]


cdef int HANDLE_RETURN(supported_error_type err) except?-1 nogil:
    if supported_error_type is cydriver.CUresult:
        if err != cydriver.CUresult.CUDA_SUCCESS:
            return _check_driver_error(err)


cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess
cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS


cpdef inline int _check_driver_error(cydriver.CUresult error) except?-1 nogil:
    if error == cydriver.CUresult.CUDA_SUCCESS:
        return 0
    cdef const char* name
    name_err = cydriver.cuGetErrorName(error, &name)
    if name_err != cydriver.CUresult.CUDA_SUCCESS:
        raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
    with gil:
        # TODO: consider lowering this to Cython
        expl = DRIVER_CU_RESULT_EXPLANATIONS.get(int(error))
        if expl is not None:
            raise CUDAError(f"{name.decode()}: {expl}")
    cdef const char* desc
    desc_err = cydriver.cuGetErrorString(error, &desc)
    if desc_err != cydriver.CUresult.CUDA_SUCCESS:
        raise CUDAError(f"{name.decode()}")
    raise CUDAError(f"{name.decode()}: {desc.decode()}")


cpdef inline int _check_runtime_error(error) except?-1:
    if error == _RUNTIME_SUCCESS:
        return 0
    name_err, name = runtime.cudaGetErrorName(error)
    if name_err != _RUNTIME_SUCCESS:
        raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
    name = name.decode()
    expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error))
    if expl is not None:
        raise CUDAError(f"{name}: {expl}")
    desc_err, desc = runtime.cudaGetErrorString(error)
    if desc_err != _RUNTIME_SUCCESS:
        raise CUDAError(f"{name}")
    desc = desc.decode()
    raise CUDAError(f"{name}: {desc}")


cpdef inline int _check_nvrtc_error(error, handle=None) except?-1:
    if error == _NVRTC_SUCCESS:
        return 0
    err = f"{error}: {nvrtc.nvrtcGetErrorString(error)[1].decode()}"
    if handle is not None:
        _, logsize = nvrtc.nvrtcGetProgramLogSize(handle)
        log = b" " * logsize
        _ = nvrtc.nvrtcGetProgramLog(handle, log)
        err += f", compilation log:\n\n{log.decode('utf-8', errors='backslashreplace')}"
    raise NVRTCError(err)


cdef inline int _check_error(error, handle=None) except?-1:
    if isinstance(error, driver.CUresult):
        return _check_driver_error(error)
    elif isinstance(error, runtime.cudaError_t):
        return _check_runtime_error(error)
    elif isinstance(error, nvrtc.nvrtcResult):
        return _check_nvrtc_error(error, handle=handle)
    else:
        raise RuntimeError(f"Unknown error type: {error}")


def handle_return(tuple result, handle=None):
    _check_error(result[0], handle=handle)
    cdef int out_len = len(result)
    if out_len == 1:
        return
    elif out_len == 2:
        return result[1]
    else:
        return result[1:]

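# Example (illustrative; assumes a working CUDA driver installation): cuda.bindings
# calls return a tuple whose first element is the status code. handle_return()
# raises on failure and unwraps whatever follows.
#
#   >>> handle_return(driver.cuInit(0))                       # 1-tuple: returns None
#   >>> version = handle_return(driver.cuDriverGetVersion())  # 2-tuple: returns the value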


cpdef check_or_create_options(type cls, options, str options_description="", bint keep_none=False):
    """
    Create the specified options dataclass from a dictionary of options or None.
    """
    if options is None:
        if keep_none:
            return options
        return cls()
    elif isinstance(options, cls):
        return options
    elif isinstance(options, dict):
        return cls(**options)
    else:
        raise TypeError(
            f"The {options_description} must be provided as an object "
            f"of type {cls.__name__} or as a dict with valid {options_description}. "
            f"The provided object is '{options}'."
        )

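# Example (illustrative, using a hypothetical options dataclass):
#
#   >>> from dataclasses import dataclass
#   >>> @dataclass
#   ... class LaunchOptions:                 # hypothetical, for illustration only
#   ...     cooperative: bool = False
#   >>> check_or_create_options(LaunchOptions, None, "launch options")
#   LaunchOptions(cooperative=False)
#   >>> check_or_create_options(LaunchOptions, {"cooperative": True}, "launch options")
#   LaunchOptions(cooperative=True)
#   >>> check_or_create_options(LaunchOptions, None, "launch options", keep_none=True) is None
#   True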


def _handle_boolean_option(option: bool) -> str:
    """
    Convert a boolean option to a string representation.
    """
    return "true" if bool(option) else "false"


def precondition(checker: Callable[..., None], str what="") -> Callable:
    """
    A decorator that adds checks to ensure any preconditions are met.

    Args:
        checker: The function to call to check whether the preconditions are met. It has
            the same signature as the wrapped function with the addition of the keyword argument `what`.
        what: A string that is passed in to `checker` to provide context information.

    Returns:
        Callable: A decorator that creates the wrapping.
    """

    def outer(wrapped_function):
        """
        A decorator that actually wraps the function for checking preconditions.
        """

        @functools.wraps(wrapped_function)
        def inner(*args, **kwargs):
            """
            Check preconditions and if they are met, call the wrapped function.
            """
            checker(*args, **kwargs, what=what)
            result = wrapped_function(*args, **kwargs)

            return result

        return inner

    return outer

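# Example (illustrative, with hypothetical helper names): the checker receives the
# same arguments as the wrapped function plus the keyword `what`, and raises to veto
# the call.
#
#   >>> def _require_positive(x, *, what=""):
#   ...     if x <= 0:
#   ...         raise ValueError(f"{what}: x must be positive (got {x})")
#   >>> @precondition(_require_positive, what="isqrt")
#   ... def isqrt(x):
#   ...     return int(x ** 0.5)
#   >>> isqrt(9)
#   3
#   >>> isqrt(-1)
#   Traceback (most recent call last):
#       ...
#   ValueError: isqrt: x must be positive (got -1)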


cdef cydriver.CUdevice get_device_from_ctx(
    cydriver.CUcontext target_ctx, cydriver.CUcontext curr_ctx) except?cydriver.CU_DEVICE_INVALID nogil:
    """Get the device ID for the given context."""
    cdef bint switch_context = (curr_ctx != target_ctx)
    cdef cydriver.CUcontext ctx
    cdef cydriver.CUdevice target_dev
    with nogil:
        if switch_context:
            HANDLE_RETURN(cydriver.cuCtxPopCurrent(&ctx))
            assert curr_ctx == ctx
            HANDLE_RETURN(cydriver.cuCtxPushCurrent(target_ctx))
        HANDLE_RETURN(cydriver.cuCtxGetDevice(&target_dev))
        if switch_context:
            HANDLE_RETURN(cydriver.cuCtxPopCurrent(&ctx))
            assert target_ctx == ctx
            HANDLE_RETURN(cydriver.cuCtxPushCurrent(curr_ctx))
    return target_dev


def is_sequence(obj):
    """
    Check if the given object is a sequence (e.g. a list or tuple).
    """
    return isinstance(obj, Sequence)


def is_nested_sequence(obj):
    """
    Check if the given object is a nested sequence (a sequence with at least one sequence element).
    """
    return is_sequence(obj) and any(is_sequence(elem) for elem in obj)

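# Example (illustrative):
#
#   >>> is_sequence((1, 2, 3))
#   True
#   >>> is_nested_sequence((1, 2, 3))
#   False
#   >>> is_nested_sequence(((1, 1, 1), (2, 2, 2)))
#   True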


@functools.lru_cache
def get_binding_version():
    try:
        major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2]
    except importlib.metadata.PackageNotFoundError:
        major_minor = importlib.metadata.version("cuda-python").split(".")[:2]
    return tuple(int(v) for v in major_minor)

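# Example (illustrative): the result depends on the installed cuda-bindings (or
# legacy cuda-python) package, so gate version-specific code on a comparison
# rather than on an exact value.
#
#   >>> major, minor = get_binding_version()
#   >>> if (major, minor) >= (12, 0):
#   ...     pass  # rely on behavior available in 12.x bindings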


class Transaction:
    """
    A context manager for transactional operations with undo capability.

    The Transaction class allows you to register undo actions (callbacks) that will be executed
    if the transaction is not committed before exiting the context. This is useful for managing
    resources or operations that need to be rolled back in case of errors or early exits.

    Usage:
        with Transaction() as txn:
            txn.append(some_cleanup_function, arg1, arg2)
            # ... perform operations ...
            txn.commit()  # Disarm undo actions; nothing will be rolled back on exit

    Methods:
        append(fn, *args, **kwargs): Register an undo action to be called on rollback.
        commit(): Disarm all undo actions; nothing will be rolled back on exit.
    """

    def __init__(self):
        self._stack = ExitStack()
        self._entered = False

    def __enter__(self):
        self._stack.__enter__()
        self._entered = True
        return self

    def __exit__(self, exc_type, exc, tb):
        # If exit callbacks remain, they run in LIFO order.
        self._entered = False
        return self._stack.__exit__(exc_type, exc, tb)

    def append(self, fn, /, *args, **kwargs):
        """
        Register an undo action (runs if the with-block exits without commit()).
        Arguments are bound immediately via partial, so later mutations of the
        passed-in values do not affect the registered callback.
        """
        if not self._entered:
            raise RuntimeError("Transaction must be entered before append()")
        self._stack.callback(partial(fn, *args, **kwargs))

    def commit(self):
        """
        Disarm all undo actions. After this, exiting the with-block does nothing.
        """
        # pop_all() empties this stack so no callbacks are triggered on exit.
        self._stack.pop_all()

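# Example (illustrative): if the with-block raises before commit(), registered
# undo actions run in LIFO order; after commit(), nothing runs on exit.
#
#   >>> undone = []
#   >>> try:
#   ...     with Transaction() as txn:
#   ...         txn.append(undone.append, "release resource")
#   ...         raise RuntimeError("boom")   # exiting without commit() rolls back
#   ... except RuntimeError:
#   ...     pass
#   >>> undone
#   ['release resource']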


# Track whether we've already warned about the fork start method
_fork_warning_checked = False


def reset_fork_warning():
    """Reset the fork warning check flag for testing purposes.

    This function is intended for use in tests to allow multiple test runs
    to check the warning behavior.
    """
    global _fork_warning_checked
    _fork_warning_checked = False


def check_multiprocessing_start_method():
    """Check if the multiprocessing start method is 'fork' and warn if so."""
    global _fork_warning_checked
    if _fork_warning_checked:
        return
    _fork_warning_checked = True

    # Common warning message parts
    common_message = (
        "CUDA does not support. Forked subprocesses exhibit undefined behavior, "
        "including failure to initialize CUDA contexts and devices. Set the start method "
        "to 'spawn' before creating processes that use CUDA. "
        "Use: multiprocessing.set_start_method('spawn')"
    )

    try:
        start_method = multiprocessing.get_start_method()
        if start_method == "fork":
            message = f"multiprocessing start method is 'fork', which {common_message}"
            warnings.warn(message, UserWarning, stacklevel=3)
    except RuntimeError:
        # get_start_method() can raise RuntimeError if the start method hasn't been set.
        # In that case the default is 'fork' on Linux, so we should warn.
        if platform.system() == "Linux":
            message = (
                f"multiprocessing start method is not set and defaults to 'fork' on Linux, "
                f"which {common_message}"
            )
            warnings.warn(message, UserWarning, stacklevel=3)
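

# Example (illustrative): callers can avoid the warning by selecting 'spawn' once,
# early in the program entry point, before any CUDA-using subprocess is created:
#
#     if __name__ == "__main__":
#         multiprocessing.set_start_method("spawn")
#         ...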