Coverage for cuda/bindings/utils/_ptx_utils.py: 78.57%

14 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-13 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE 

3 

4import re 

5 

# PTX ISA version -> (major, minor) of the minimal CUDA release able to load it,
# transcribed from the official "PTX ISA <-> CUDA release" history table:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes-ptx-release-history
# Entries are grouped by CUDA major release and listed in ascending PTX order.
_ptx_to_cuda = {
    # CUDA 1.x
    "1.0": (1, 0),
    "1.1": (1, 1),
    # CUDA 2.x
    "1.2": (2, 0),
    "1.3": (2, 1),
    "1.4": (2, 2),
    # CUDA 3.x
    "2.0": (3, 0),
    "2.1": (3, 1),
    "2.2": (3, 2),
    # CUDA 4.x
    "2.3": (4, 0),
    "3.0": (4, 1),
    # CUDA 5.x
    "3.1": (5, 0),
    "3.2": (5, 5),
    # CUDA 6.x
    "4.0": (6, 0),
    "4.1": (6, 5),
    # CUDA 7.x
    "4.2": (7, 0),
    "4.3": (7, 5),
    # CUDA 8.x
    "5.0": (8, 0),
    # CUDA 9.x
    "6.0": (9, 0),
    "6.1": (9, 1),
    "6.2": (9, 2),
    # CUDA 10.x
    "6.3": (10, 0),
    "6.4": (10, 1),
    "6.5": (10, 2),
    # CUDA 11.x
    "7.0": (11, 0),
    "7.1": (11, 1),
    "7.2": (11, 2),
    "7.3": (11, 3),
    "7.4": (11, 4),
    "7.5": (11, 5),
    "7.6": (11, 6),
    "7.7": (11, 7),
    "7.8": (11, 8),
    # CUDA 12.x -- note that no entry maps to (12, 6); presumably CUDA 12.6
    # introduced no new PTX ISA version (confirm against the table above).
    "8.0": (12, 0),
    "8.1": (12, 1),
    "8.2": (12, 2),
    "8.3": (12, 3),
    "8.4": (12, 4),
    "8.5": (12, 5),
    "8.6": (12, 7),
    "8.7": (12, 8),
    "8.8": (12, 9),
    # CUDA 13.x
    "9.0": (13, 0),
    "9.1": (13, 1),
    "9.2": (13, 2),
    "9.3": (13, 3),
}

55 

56 

def get_minimal_required_cuda_ver_from_ptx_ver(ptx_version: str) -> int:
    """
    Map a PTX ISA version to the minimal CUDA version that can consume it.

    The result is the minimal CUDA driver, nvPTXCompiler, or nvJitLink version
    needed to load a PTX module of the given ISA version.

    Parameters
    ----------
    ptx_version : str
        PTX ISA version as a string, e.g. "8.8" for PTX ISA 8.8. This is the
        ``.version`` directive in the PTX header.

    Returns
    -------
    int
        Minimal CUDA version encoded as ``1000 * major + 10 * minor``,
        e.g. 12090 for CUDA 12.9.

    Raises
    ------
    ValueError
        If the PTX version is unknown.

    Examples
    --------
    >>> get_minimal_required_cuda_ver_from_ptx_ver("8.8")
    12090
    >>> get_minimal_required_cuda_ver_from_ptx_ver("7.0")
    11000
    """
    # Keep the try body to the lookup only, so that nothing else can be
    # misreported as an unknown PTX version.
    try:
        major, minor = _ptx_to_cuda[ptx_version]
    except KeyError:
        # ``from None`` deliberately hides the internal KeyError: callers only
        # need to know the version is not in the table.
        raise ValueError(f"Unknown or unsupported PTX ISA version: {ptx_version}") from None
    return 1000 * major + 10 * minor

90 

91 

# Regex pattern to match the ``.version`` directive and capture the version number.
# The directive is anchored to the start of a line (leading blanks allowed, via
# re.MULTILINE). This stops a ``.version`` token inside a ``//`` comment, or
# anywhere mid-line, from being mistaken for the module's real directive when
# such a comment precedes the header.
# TODO: if import speed is a concern, consider lazy-initializing it.
_ptx_ver_pattern = re.compile(r"^[ \t]*\.version\s+([0-9]+\.[0-9]+)", re.MULTILINE)

95 

96 

def get_ptx_ver(ptx: str) -> str:
    """
    Return the PTX ISA version declared by a PTX module's ``.version`` directive.

    Parameters
    ----------
    ptx : str
        PTX assembly source text.

    Returns
    -------
    str
        The version number captured by the module-level ``.version`` pattern,
        e.g. ``"8.8"``.

    Raises
    ------
    ValueError
        If no ``.version`` directive can be located in ``ptx``.

    Examples
    --------
    >>> ptx = r'''
    ... .version 8.8
    ... .target sm_86
    ... .address_size 64
    ...
    ... .visible .entry test_kernel()
    ... {
    ...     ret;
    ... }
    ... '''
    >>> get_ptx_ver(ptx)
    '8.8'
    """
    directive = _ptx_ver_pattern.search(ptx)
    # Guard clause: a failed search yields None, which means no header was found.
    if directive is None:
        raise ValueError("No .version directive found in PTX source. Is it a valid PTX?")
    return directive.group(1)