Coverage for cuda / bindings / _example_helpers / helper_cuda.py: 0.00%
34 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 01:27 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 01:27 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4from cuda.bindings import driver as cuda
5from cuda.bindings import nvrtc
6from cuda.bindings import runtime as cudart
8from .helper_string import check_cmd_line_flag, get_cmd_line_argument_int
def _cuda_get_error_enum(error):
    """Return a descriptive name for a CUDA driver, runtime, or NVRTC status.

    Args:
        error: A ``cuda.CUresult``, ``cudart.cudaError_t`` or
            ``nvrtc.nvrtcResult`` status value.

    Returns:
        For a driver ``CUresult``: the name from ``cuGetErrorName``, or the
        string ``"<unknown>"`` when that lookup itself does not succeed.
        For runtime and NVRTC statuses: the second element of the tuple
        returned by ``cudaGetErrorName`` / ``nvrtcGetErrorString``. The
        lookup status for these two is not checked.

    Raises:
        RuntimeError: If ``error`` is none of the three supported types.
    """
    if isinstance(error, cuda.CUresult):
        lookup_status, label = cuda.cuGetErrorName(error)
        if lookup_status == cuda.CUresult.CUDA_SUCCESS:
            return label
        # The name lookup failed, so fall back to a placeholder.
        return "<unknown>"

    if isinstance(error, cudart.cudaError_t):
        runtime_reply = cudart.cudaGetErrorName(error)
        return runtime_reply[1]

    if isinstance(error, nvrtc.nvrtcResult):
        nvrtc_reply = nvrtc.nvrtcGetErrorString(error)
        return nvrtc_reply[1]

    raise RuntimeError(f"Unknown error type: {error}")
def check_cuda_errors(result):
    """Validate a CUDA binding result tuple and unwrap its payload.

    Args:
        result: An indexable sequence whose first element is a status
            object exposing ``.value``. A non-zero ``.value`` means failure.
            The remaining elements are the call's outputs.

    Returns:
        ``None`` when ``result`` holds only the status. The single output
        when there is exactly one. Otherwise the slice ``result[1:]``
        holding all outputs.

    Raises:
        RuntimeError: If the status ``.value`` is non-zero. The message
            includes the numeric code and the name from
            ``_cuda_get_error_enum``.
    """
    status = result[0]
    if status.value:
        raise RuntimeError(f"CUDA error code={status.value}({_cuda_get_error_enum(status)})")

    # Count the outputs that follow the leading status element.
    payload_count = len(result) - 1
    if payload_count == 0:
        return None
    if payload_count == 1:
        return result[1]
    return result[1:]
def find_cuda_device():
    """Pick a CUDA device via the runtime API and make it current.

    The device index defaults to 0. If a ``device=`` command-line flag is
    present, the index comes from ``get_cmd_line_argument_int("device=")``.

    Returns:
        The device index passed to ``cudart.cudaSetDevice``.

    Raises:
        RuntimeError: Via ``check_cuda_errors``, if ``cudaSetDevice``
            reports a non-zero status.
    """
    chosen = get_cmd_line_argument_int("device=") if check_cmd_line_flag("device=") else 0
    check_cuda_errors(cudart.cudaSetDevice(chosen))
    return chosen
def find_cuda_device_drv():
    """Initialise the CUDA driver API and fetch the requested device.

    The device ordinal defaults to 0. If a ``device=`` command-line flag is
    present, the ordinal comes from ``get_cmd_line_argument_int("device=")``.
    The driver is then initialised with ``cuInit(0)`` before the device is
    looked up.

    Returns:
        The value unwrapped by ``check_cuda_errors`` from
        ``cuda.cuDeviceGet`` (presumably the driver device handle).

    Raises:
        RuntimeError: Via ``check_cuda_errors``, if ``cuInit`` or
            ``cuDeviceGet`` reports a non-zero status.
    """
    ordinal = (
        get_cmd_line_argument_int("device=")
        if check_cmd_line_flag("device=")
        else 0
    )
    # The driver API must be initialised before any device query.
    check_cuda_errors(cuda.cuInit(0))
    return check_cuda_errors(cuda.cuDeviceGet(ordinal))