Coverage for cuda/bindings/utils/_ptx_utils.py: 78.57%
14 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4import re
6# Mapping based on the official PTX ISA <-> CUDA Release table
7# https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes-ptx-release-history
8_ptx_to_cuda = {
9 "1.0": (1, 0),
10 "1.1": (1, 1),
11 "1.2": (2, 0),
12 "1.3": (2, 1),
13 "1.4": (2, 2),
14 "2.0": (3, 0),
15 "2.1": (3, 1),
16 "2.2": (3, 2),
17 "2.3": (4, 0),
18 "3.0": (4, 1),
19 "3.1": (5, 0),
20 "3.2": (5, 5),
21 "4.0": (6, 0),
22 "4.1": (6, 5),
23 "4.2": (7, 0),
24 "4.3": (7, 5),
25 "5.0": (8, 0),
26 "6.0": (9, 0),
27 "6.1": (9, 1),
28 "6.2": (9, 2),
29 "6.3": (10, 0),
30 "6.4": (10, 1),
31 "6.5": (10, 2),
32 "7.0": (11, 0),
33 "7.1": (11, 1),
34 "7.2": (11, 2),
35 "7.3": (11, 3),
36 "7.4": (11, 4),
37 "7.5": (11, 5),
38 "7.6": (11, 6),
39 "7.7": (11, 7),
40 "7.8": (11, 8),
41 "8.0": (12, 0),
42 "8.1": (12, 1),
43 "8.2": (12, 2),
44 "8.3": (12, 3),
45 "8.4": (12, 4),
46 "8.5": (12, 5),
47 "8.6": (12, 7),
48 "8.7": (12, 8),
49 "8.8": (12, 9),
50 "9.0": (13, 0),
51 "9.1": (13, 1),
52 "9.2": (13, 2),
53 "9.3": (13, 3),
54}
57def get_minimal_required_cuda_ver_from_ptx_ver(ptx_version: str) -> int:
58 """
59 Maps the PTX ISA version to the minimal CUDA driver, nvPTXCompiler, or nvJitLink version
60 that is needed to load a PTX of the given ISA version.
62 Parameters
63 ----------
64 ptx_version : str
65 PTX ISA version as a string, e.g. "8.8" for PTX ISA 8.8. This is the ``.version``
66 directive in the PTX header.
68 Returns
69 -------
70 int
71 Minimal CUDA version as 1000 * major + 10 * minor, e.g. 12090 for CUDA 12.9.
73 Raises
74 ------
75 ValueError
76 If the PTX version is unknown.
78 Examples
79 --------
80 >>> get_minimal_required_driver_ver_from_ptx_ver("8.8")
81 12090
82 >>> get_minimal_required_driver_ver_from_ptx_ver("7.0")
83 11000
84 """
85 try: 1ab
86 major, minor = _ptx_to_cuda[ptx_version] 1ab
87 return 1000 * major + 10 * minor 1ab
88 except KeyError:
89 raise ValueError(f"Unknown or unsupported PTX ISA version: {ptx_version}") from None
# Compiled once at import time: finds the first ``.version <major>.<minor>``
# directive and captures the "<major>.<minor>" text as group 1.
# TODO: if import speed is a concern, consider lazy-initializing it.
_ptx_ver_pattern = re.compile(r"\.version\s+([0-9]+\.[0-9]+)")


def get_ptx_ver(ptx: str) -> str:
    """
    Return the PTX ISA version declared by a PTX module's ``.version`` directive.

    The first occurrence of ``.version`` followed by whitespace and a
    ``<digits>.<digits>`` number is used.

    Parameters
    ----------
    ptx : str
        The PTX assembly source code as a string.

    Returns
    -------
    str
        The PTX ISA version string, e.g., "8.8".

    Raises
    ------
    ValueError
        If the .version directive is not found in the PTX source.

    Examples
    --------
    >>> ptx = r'''
    ... .version 8.8
    ... .target sm_86
    ... .address_size 64
    ...
    ... .visible .entry test_kernel()
    ... {
    ...   ret;
    ... }
    ... '''
    >>> get_ptx_ver(ptx)
    '8.8'
    """
    version_match = _ptx_ver_pattern.search(ptx)
    if version_match is None:
        # Nothing resembling a version directive: treat the input as not PTX.
        raise ValueError("No .version directive found in PTX source. Is it a valid PTX?")
    return version_match[1]