Coverage for cuda / bindings / utils / _ptx_utils.py: 79%
14 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4import re
6# Mapping based on the official PTX ISA <-> CUDA Release table
7# https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes-ptx-release-history
8_ptx_to_cuda = {
9 "1.0": (1, 0),
10 "1.1": (1, 1),
11 "1.2": (2, 0),
12 "1.3": (2, 1),
13 "1.4": (2, 2),
14 "2.0": (3, 0),
15 "2.1": (3, 1),
16 "2.2": (3, 2),
17 "2.3": (4, 0),
18 "3.0": (4, 1),
19 "3.1": (5, 0),
20 "3.2": (5, 5),
21 "4.0": (6, 0),
22 "4.1": (6, 5),
23 "4.2": (7, 0),
24 "4.3": (7, 5),
25 "5.0": (8, 0),
26 "6.0": (9, 0),
27 "6.1": (9, 1),
28 "6.2": (9, 2),
29 "6.3": (10, 0),
30 "6.4": (10, 1),
31 "6.5": (10, 2),
32 "7.0": (11, 0),
33 "7.1": (11, 1),
34 "7.2": (11, 2),
35 "7.3": (11, 3),
36 "7.4": (11, 4),
37 "7.5": (11, 5),
38 "7.6": (11, 6),
39 "7.7": (11, 7),
40 "7.8": (11, 8),
41 "8.0": (12, 0),
42 "8.1": (12, 1),
43 "8.2": (12, 2),
44 "8.3": (12, 3),
45 "8.4": (12, 4),
46 "8.5": (12, 5),
47 "8.6": (12, 7),
48 "8.7": (12, 8),
49 "8.8": (12, 9),
50 "9.0": (13, 0),
51 "9.1": (13, 1),
52}
def get_minimal_required_cuda_ver_from_ptx_ver(ptx_version: str) -> int:
    """
    Map a PTX ISA version to the minimal CUDA version able to consume it.

    The result is the minimal CUDA driver, nvPTXCompiler, or nvJitLink version
    that is needed to load a PTX of the given ISA version.

    Parameters
    ----------
    ptx_version : str
        PTX ISA version as a string, e.g. "8.8" for PTX ISA 8.8. This is the ``.version``
        directive in the PTX header.

    Returns
    -------
    int
        Minimal CUDA version as 1000 * major + 10 * minor, e.g. 12090 for CUDA 12.9.

    Raises
    ------
    ValueError
        If the PTX version is unknown.

    Examples
    --------
    >>> get_minimal_required_cuda_ver_from_ptx_ver("8.8")
    12090
    >>> get_minimal_required_cuda_ver_from_ptx_ver("7.0")
    11000
    """
    # Keep the try body to the lookup only: it is the single operation that is
    # expected to raise KeyError for an unknown version.
    try:
        major, minor = _ptx_to_cuda[ptx_version]
    except KeyError:
        # Surface the documented ValueError; ``from None`` hides the internal
        # KeyError from the traceback.
        raise ValueError(f"Unknown or unsupported PTX ISA version: {ptx_version}") from None
    return 1000 * major + 10 * minor
# Compiled regex locating the ``.version`` directive of a PTX module; group 1
# captures the "<major>.<minor>" ISA version number (e.g. "8.8").
# TODO: if import speed is a concern, consider lazy-initializing it.
_ptx_ver_pattern = re.compile(
    r"\.version"  # literal directive keyword
    r"\s+"  # one or more whitespace characters
    r"([0-9]+\.[0-9]+)"  # captured version, digits '.' digits
)
def get_ptx_ver(ptx: str) -> str:
    """
    Return the PTX ISA version declared by the ``.version`` directive in *ptx*.

    The first match of the module's ``.version`` pattern in the text is used.

    Parameters
    ----------
    ptx : str
        PTX assembly source code.

    Returns
    -------
    str
        The ISA version exactly as written in the directive, e.g. "8.8".

    Raises
    ------
    ValueError
        If no ``.version`` directive can be found in *ptx*.

    Examples
    --------
    >>> src = r'''
    ... .version 8.8
    ... .target sm_86
    ... .address_size 64
    ...
    ... .visible .entry test_kernel()
    ... {
    ...     ret;
    ... }
    ... '''
    >>> get_ptx_ver(src)
    '8.8'
    """
    found = _ptx_ver_pattern.search(ptx)
    # Guard clause: a missing directive means the text is not usable PTX.
    if found is None:
        raise ValueError("No .version directive found in PTX source. Is it a valid PTX?")
    return found.group(1)