Coverage for cuda / bindings / _example_helpers / helper_cuda.py: 0.00%

34 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-29 01:27 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE 

3 

4from cuda.bindings import driver as cuda 

5from cuda.bindings import nvrtc 

6from cuda.bindings import runtime as cudart 

7 

8from .helper_string import check_cmd_line_flag, get_cmd_line_argument_int 

9 

10 

def _cuda_get_error_enum(error):
    """Return a printable identifier for a CUDA driver, runtime, or NVRTC status.

    Args:
        error: A ``cuda.CUresult``, ``cudart.cudaError_t`` or
            ``nvrtc.nvrtcResult`` status value.

    Returns:
        The value that the backend's lookup call (``cuGetErrorName``,
        ``cudaGetErrorName`` or ``nvrtcGetErrorString``) returns in position 1
        of its result tuple, or the string ``"<unknown>"`` when that lookup
        call itself reports a non-success status.

    Raises:
        RuntimeError: If ``error`` is not one of the supported status types.
    """
    if isinstance(error, cuda.CUresult):
        lookup = cuda.cuGetErrorName(error)
        success = cuda.CUresult.CUDA_SUCCESS
    elif isinstance(error, cudart.cudaError_t):
        lookup = cudart.cudaGetErrorName(error)
        success = cudart.cudaError_t.cudaSuccess
    elif isinstance(error, nvrtc.nvrtcResult):
        # NVRTC has no "error name" query; its error string is the closest equivalent.
        lookup = nvrtc.nvrtcGetErrorString(error)
        success = nvrtc.nvrtcResult.NVRTC_SUCCESS
    else:
        raise RuntimeError(f"Unknown error type: {error}")
    # Every backend returns (status, value). Previously only the driver branch
    # checked the status; apply the same fallback everywhere so a failed lookup
    # never leaks a placeholder value into the caller's error message.
    status, name = lookup[0], lookup[1]
    return name if status == success else "<unknown>"

21 

22 

def check_cuda_errors(result):
    """Validate and unwrap a ``(status, *outputs)`` result from a CUDA binding call.

    Args:
        result: A sequence whose first element is a status object exposing
            ``.value`` (zero on success), followed by the call's outputs, if any.

    Returns:
        ``None`` if there are no outputs, the sole output if there is exactly
        one, and otherwise ``result[1:]`` (all outputs, same sequence type).

    Raises:
        RuntimeError: If the status has a non-zero ``.value``. The message
            includes the numeric code and the name from ``_cuda_get_error_enum``.
    """
    status = result[0]
    if status.value:
        raise RuntimeError(f"CUDA error code={status.value}({_cuda_get_error_enum(status)})")

    outputs = result[1:]
    if not outputs:
        return None
    if len(outputs) == 1:
        # Unwrap single-output calls so callers get the value, not a 1-tuple.
        return outputs[0]
    return outputs

32 

33 

def find_cuda_device():
    """Choose a device ordinal and make it current via the CUDA runtime API.

    The ordinal is 0 unless the ``device=`` command-line option is present,
    in which case it is read with ``get_cmd_line_argument_int``.

    Returns:
        The device ordinal that was passed to ``cudart.cudaSetDevice``.

    Raises:
        RuntimeError: Via ``check_cuda_errors`` if ``cudaSetDevice`` fails.
    """
    device = get_cmd_line_argument_int("device=") if check_cmd_line_flag("device=") else 0
    check_cuda_errors(cudart.cudaSetDevice(device))
    return device

40 

41 

def find_cuda_device_drv():
    """Initialize the CUDA driver API and return the handle for the chosen device.

    The ordinal is 0 unless the ``device=`` command-line option is present,
    in which case it is read with ``get_cmd_line_argument_int``. Command-line
    parsing happens before ``cuInit`` is called.

    Returns:
        The device handle produced by ``cuda.cuDeviceGet`` for that ordinal.

    Raises:
        RuntimeError: Via ``check_cuda_errors`` if ``cuInit`` or
            ``cuDeviceGet`` fails.
    """
    ordinal = 0
    if check_cmd_line_flag("device="):
        ordinal = get_cmd_line_argument_int("device=")

    check_cuda_errors(cuda.cuInit(0))
    return check_cuda_errors(cuda.cuDeviceGet(ordinal))