Coverage for cuda / pathfinder / _utils / env_vars.py: 94.29%
35 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
"""Centralized CUDA environment variable handling.

This module defines the canonical search order for CUDA Toolkit environment variables
used throughout cuda-python packages (cuda.pathfinder, cuda.core, cuda.bindings).

Search Order Priority:
    1. CUDA_PATH (higher priority)
    2. CUDA_HOME (lower priority)

Variables set to an empty string are treated as unset. If both are set and refer to
different locations (after path normalization and, when both paths exist, symlink
resolution), CUDA_PATH takes precedence and a UserWarning is issued.

Important Note on Caching:
    The result of get_cuda_path_or_home() is cached for the process lifetime. The first
    call determines the CUDA Toolkit path, and all subsequent calls return the cached
    value, even if environment variables change later. This ensures consistent behavior
    throughout the application lifecycle. Code that deliberately changes the environment
    (e.g. tests) can call get_cuda_path_or_home.cache_clear() to force re-evaluation.
"""
22import functools
23import os
24import warnings
# Environment variables that may name the CUDA Toolkit root, listed from
# highest to lowest priority; get_cuda_path_or_home() consults them in order.
_CUDA_PATH_ENV_VARS_ORDERED = (
    "CUDA_PATH",  # highest priority
    "CUDA_HOME",
)
def _paths_differ(a: str, b: str) -> bool:
    """Report whether two filesystem paths refer to observably different locations.

    The comparison proceeds in stages, cheapest first:

    1. Canonicalize both strings with ``os.path.normpath`` followed by
       ``os.path.normcase``. This collapses redundant separators, ``.``/``..``
       components and trailing slashes, and folds case on Windows. Matching
       canonical forms mean "same path".
    2. Otherwise, when both paths exist, defer to ``os.path.samefile`` so that
       symlinks/junctions resolving to the same entry compare as equal.
    3. In every other case (a path is missing, or ``samefile`` raises
       ``OSError``) the paths are conservatively reported as different.

    Args:
        a: First path.
        b: Second path.

    Returns:
        False if the paths were shown to refer to the same location,
        True otherwise.
    """

    def canonical(path: str) -> str:
        return os.path.normcase(os.path.normpath(path))

    if canonical(a) == canonical(b):
        return False

    # samefile() stats both paths and raises on missing ones, so only use it
    # when both exist.
    if not (os.path.exists(a) and os.path.exists(b)):
        return True

    try:
        return not os.path.samefile(a, b)
    except OSError:
        # Could not stat one of the entries; equality cannot be proven.
        return True
@functools.cache
def get_cuda_path_or_home() -> str | None:
    """Get the CUDA Toolkit path from environment variables.

    Variables are consulted in ``_CUDA_PATH_ENV_VARS_ORDERED`` order
    (CUDA_PATH, then CUDA_HOME). The value of the first variable that is set to
    a non-empty string is returned; empty strings are treated as unset.

    If more than one variable is set and their values refer to different
    locations (as decided by ``_paths_differ``), a ``UserWarning`` is issued
    listing every set variable and naming the one whose value is used.

    The result is cached for the process lifetime via ``functools.cache``. The
    first call determines the CUDA Toolkit path, and later calls return the
    cached value even if the environment changes. Call
    ``get_cuda_path_or_home.cache_clear()`` to force re-evaluation.

    Returns:
        Path to the CUDA Toolkit, or None if no variable is set or all are empty.

    Warns:
        UserWarning: If multiple variables are set but point to different
            locations (only on the first, uncached call).
    """
    # Collect non-empty variables in priority order (dicts preserve insertion
    # order). No valid CUDA path is the empty string.
    set_vars: dict[str, str] = {}
    for var in _CUDA_PATH_ENV_VARS_ORDERED:
        val = os.environ.get(var)
        if val:
            set_vars[var] = val

    if not set_vars:
        return None

    # The highest-priority variable that is actually set wins. Naming it here
    # (rather than assuming index 0 of the priority tuple) keeps the warning
    # accurate if more variables are ever added to the tuple.
    chosen_var, chosen_val = next(iter(set_vars.items()))

    if len(set_vars) > 1:
        values = list(set_vars.values())
        # Compare neighbouring values; any() stops at the first mismatch.
        if any(_paths_differ(first, second) for first, second in zip(values, values[1:])):
            var_list = "\n".join(f" {var}={val}" for var, val in set_vars.items())
            warnings.warn(
                f"Multiple CUDA environment variables are set but differ:\n"
                f"{var_list}\n"
                f"Using {chosen_var} (highest priority).",
                UserWarning,
                # functools.cache adds no Python frame, so 2 points at the caller.
                stacklevel=2,
            )

    return chosen_val