Coverage for cuda / pathfinder / _headers / find_nvidia_headers.py: 68%
88 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
4import functools
5import glob
6import os
8from cuda.pathfinder._headers import supported_nvidia_headers
9from cuda.pathfinder._utils.env_vars import get_cuda_home_or_path
10from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages
11from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
14def _abs_norm(path: str | None) -> str | None:
15 if path:
16 return os.path.normpath(os.path.abspath(path))
17 return None
def _joined_isfile(dirpath: str, basename: str) -> bool:
    """Tell whether ``basename`` names an existing regular file inside ``dirpath``."""
    candidate = os.path.join(dirpath, basename)
    return os.path.isfile(candidate)
def _find_under_site_packages(sub_dir: str, h_basename: str) -> str | None:
    """Return the first site-packages directory matching ``sub_dir`` that holds ``h_basename``.

    This covers headers installed from a wheel. ``sub_dir`` is a
    ``/``-separated relative path; it is split into components before being
    passed to ``find_sub_dirs_all_sitepackages``. Returns ``None`` when no
    candidate directory contains the header.
    """
    path_parts = tuple(sub_dir.split("/"))
    matching_dirs = (
        candidate
        for candidate in find_sub_dirs_all_sitepackages(path_parts)
        if _joined_isfile(candidate, h_basename)
    )
    return next(matching_dirs, None)
def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str) -> str | None:
    """Look for ``h_basename`` in a CUDA Toolkit-style include tree rooted at ``anchor_point``.

    The include directory is ``<anchor_point>/include``, except for ``nvvm``,
    which uses ``<anchor_point>/nvvm/include``. For ``cccl``, extra
    sub-directories are probed before the plain include directory. Returns the
    first directory that contains ``h_basename``, or ``None``.
    """
    if libname == "nvvm":
        include_dir = os.path.join(anchor_point, libname, "include")
    else:
        include_dir = os.path.join(anchor_point, "include")

    # Candidate directories, in priority order.
    probe_dirs: list[str] = []
    if libname == "cccl":
        if IS_WINDOWS:
            x64_dir = os.path.join(include_dir, "targets", "x64")  # conda has this anomaly
            probe_dirs.append(os.path.join(x64_dir, "cccl"))  # CTK 13 variant
            probe_dirs.append(x64_dir)  # CTK 12 variant
        probe_dirs.append(os.path.join(include_dir, "cccl"))  # CTK 13
    probe_dirs.append(include_dir)

    for probe in probe_dirs:
        if _joined_isfile(probe, h_basename):
            return probe
    return None
55def _find_based_on_conda_layout(libname: str, h_basename: str, ctk_layout: bool) -> str | None:
56 conda_prefix = os.environ.get("CONDA_PREFIX")
57 if not conda_prefix:
58 return None
59 if IS_WINDOWS:
60 anchor_point = os.path.join(conda_prefix, "Library")
61 if not os.path.isdir(anchor_point):
62 return None
63 else:
64 if ctk_layout:
65 targets_include_path = glob.glob(os.path.join(conda_prefix, "targets", "*", "include"))
66 if not targets_include_path:
67 return None
68 if len(targets_include_path) != 1:
69 # Conda does not support multiple architectures.
70 # QUESTION(PR#956): Do we want to issue a warning?
71 return None
72 include_path = targets_include_path[0]
73 else:
74 include_path = os.path.join(conda_prefix, "include")
75 anchor_point = os.path.dirname(include_path)
76 return _find_based_on_ctk_layout(libname, h_basename, anchor_point)
def _find_ctk_header_directory(libname: str) -> str | None:
    """Locate the header directory of a CUDA Toolkit (CTK) library, unnormalized.

    Sources are tried in this order:

    1. NVIDIA wheels in site-packages.
    2. The active Conda environment, using the CTK layout.
    3. The directory returned by ``get_cuda_home_or_path()``.

    Returns the first directory containing the library's marker header, or
    ``None``.
    """
    h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_CTK[libname]
    site_package_dirs = supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK[libname]

    for sub_dir in site_package_dirs:
        found = _find_under_site_packages(sub_dir, h_basename)
        if found:
            return found

    found = _find_based_on_conda_layout(libname, h_basename, ctk_layout=True)
    if found:
        return found

    cuda_home = get_cuda_home_or_path()
    if not cuda_home:
        return None
    return _find_based_on_ctk_layout(libname, h_basename, cuda_home) or None
def _find_non_ctk_header_directory(libname: str, h_basename: str) -> str | None:
    """Locate the header directory of a non-CTK library, unnormalized.

    Counterpart of ``_find_ctk_header_directory`` for libraries listed in
    ``SUPPORTED_HEADERS_NON_CTK``. Returns the first directory that contains
    ``h_basename``, or ``None``.
    """
    hdr_dir: str | None  # help mypy

    # 1. NVIDIA wheels installed in site-packages.
    for cdir in supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_NON_CTK.get(libname, []):
        if hdr_dir := _find_under_site_packages(cdir, h_basename):
            return hdr_dir

    # 2. Active Conda environment, searched with the non-CTK layout.
    if hdr_dir := _find_based_on_conda_layout(libname, h_basename, False):
        return hdr_dir

    # 3. Known install-location glob patterns for this library. Matches are
    #    tried in reverse lexicographic order, presumably so the highest-versioned
    #    directory wins. NOTE(review): this is a plain string sort, not a true
    #    version sort.
    for pattern in supported_nvidia_headers.SUPPORTED_INSTALL_DIRS_NON_CTK.get(libname, []):
        for hdr_dir in sorted(glob.glob(pattern), reverse=True):
            if _joined_isfile(hdr_dir, h_basename):
                return hdr_dir

    return None


@functools.cache
def find_nvidia_header_directory(libname: str) -> str | None:
    """Locate the header directory for a supported NVIDIA library.

    Args:
        libname (str): The short name of the library whose headers are needed
            (e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).

    Returns:
        str or None: Absolute, normalized path to the discovered header
        directory, or ``None`` if the headers cannot be found.

    Raises:
        RuntimeError: If ``libname`` is not in the supported set.

    Search order:
        1. **NVIDIA Python wheels**

           - Scan installed distributions (``site-packages``) for header layouts
             shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).

        2. **Conda environments**

           - Check the Conda prefix given by ``$CONDA_PREFIX``, which uses
             platform-specific include directory layouts.

        3. **Fallback locations**

           - CUDA Toolkit libraries: use ``CUDA_HOME`` or ``CUDA_PATH``
             (in that order).
           - Other (non-CTK) libraries: probe the library's known install
             directory glob patterns, trying matches in reverse lexicographic
             order.

    Results, including ``None``, are memoized per ``libname`` via
    ``functools.cache``. Environment changes made after the first lookup of a
    given library are therefore not observed.
    """
    if libname in supported_nvidia_headers.SUPPORTED_HEADERS_CTK:
        return _abs_norm(_find_ctk_header_directory(libname))

    h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_NON_CTK.get(libname)
    if h_basename is None:
        raise RuntimeError(f"UNKNOWN {libname=}")

    return _abs_norm(_find_non_ctk_header_directory(libname, h_basename))