Coverage for cuda / pathfinder / _dynamic_libs / load_dl_windows.py: 78.57%
70 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
6import ctypes
7import ctypes.wintypes
8import os
9import struct
10from typing import TYPE_CHECKING
12from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
14if TYPE_CHECKING:
15 from cuda.pathfinder._dynamic_libs.lib_descriptor import LibDescriptor
# dwFlags values for LoadLibraryExW, copied from WinBase.h because neither
# ctypes nor ctypes.wintypes exposes them.
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x100
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x1000

# Count of distinct values a native pointer can hold (e.g. 2**64 on a 64-bit
# interpreter). Adding it to a negative handle value yields the unsigned
# representation of the same bit pattern.
POINTER_ADDRESS_SPACE = 1 << (8 * struct.calcsize("P"))
# Set up kernel32 functions with proper types.
#
# A private WinDLL instance is used rather than the shared
# ``ctypes.windll.kernel32``: ``ctypes.windll`` hands out one cached library
# object per process, so assigning ``argtypes``/``restype`` on it would change
# how every other module in the process calls these functions (e.g. code that
# relies on the default ``int`` restype would start receiving ``None`` instead
# of ``0`` for a NULL handle). The prototypes below only affect this module.
#
# ``use_last_error`` is deliberately left at its default: the callers in this
# module read the error with ``ctypes.GetLastError()`` right after a failed
# call, and enabling ``use_last_error`` would swap that value into ctypes'
# private copy instead.
kernel32 = ctypes.WinDLL("kernel32")  # type: ignore[attr-defined]

# GetModuleHandleW
kernel32.GetModuleHandleW.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.GetModuleHandleW.restype = ctypes.wintypes.HMODULE

# LoadLibraryExW
kernel32.LoadLibraryExW.argtypes = [
    ctypes.wintypes.LPCWSTR,  # lpLibFileName
    ctypes.wintypes.HANDLE,  # hFile (reserved, must be NULL)
    ctypes.wintypes.DWORD,  # dwFlags
]
kernel32.LoadLibraryExW.restype = ctypes.wintypes.HMODULE

# GetModuleFileNameW
kernel32.GetModuleFileNameW.argtypes = [
    ctypes.wintypes.HMODULE,  # hModule
    ctypes.wintypes.LPWSTR,  # lpFilename
    ctypes.wintypes.DWORD,  # nSize
]
kernel32.GetModuleFileNameW.restype = ctypes.wintypes.DWORD

# AddDllDirectory (Windows 7+)
kernel32.AddDllDirectory.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.AddDllDirectory.restype = ctypes.c_void_p  # DLL_DIRECTORY_COOKIE
def ctypes_handle_to_unsigned_int(handle: ctypes.wintypes.HMODULE) -> int:
    """Return *handle* as a non-negative integer.

    A value that comes back negative (the signed view of the pointer's bit
    pattern) is shifted into the unsigned range by adding
    ``POINTER_ADDRESS_SPACE``; non-negative values are returned unchanged.
    """
    value = int(handle)
    return value + POINTER_ADDRESS_SPACE if value < 0 else value
def add_dll_directory(dll_abs_path: str) -> None:
    """Add a DLL's directory to the DLL search path and to the PATH environment variable.

    The directory containing *dll_abs_path* is registered with
    ``AddDllDirectory``. It is also appended to ``PATH`` as a fallback for
    dependent DLL resolution. The append happens only once: if ``PATH``
    already lists the directory, ``PATH`` is left unchanged.

    Args:
        dll_abs_path: Absolute path to the DLL file

    Raises:
        AssertionError: If the directory containing the DLL does not exist
    """
    dirpath = os.path.dirname(dll_abs_path)
    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # under ``python -O``; the exception type is kept for existing callers.
    if not os.path.isdir(dirpath):
        raise AssertionError(dll_abs_path)

    # Best effort: a falsy cookie means AddDllDirectory failed. The PATH update
    # below is applied regardless, so the result needs no further handling.
    kernel32.AddDllDirectory(dirpath)

    # Update PATH as a fallback for dependent DLL resolution. This function can
    # be called repeatedly for the same directory. The directory is therefore
    # appended only if PATH does not already list it, so PATH does not grow
    # with duplicates.
    curr_path = os.environ.get("PATH")
    if not curr_path:
        # Unset or empty PATH: avoid producing a leading empty entry.
        os.environ["PATH"] = dirpath
        return

    wanted = os.path.normcase(os.path.normpath(dirpath))
    already_listed = any(
        os.path.normcase(os.path.normpath(entry)) == wanted for entry in curr_path.split(os.pathsep) if entry
    )
    if not already_listed:
        os.environ["PATH"] = os.pathsep.join((curr_path, dirpath))
def abs_path_for_dynamic_library(libname: str, handle: ctypes.wintypes.HMODULE) -> str:
    """Return the file path that GetModuleFileNameW reports for *handle*.

    The first attempt uses a MAX_PATH (260 character) buffer. If the reported
    length fills that buffer exactly, the call is repeated once with a 32768
    character buffer (extended path length). The result of that second attempt
    is returned as-is.

    Args:
        libname: Library name; only used in error messages.
        handle: Module handle of the loaded library.

    Returns:
        The path string read from the buffer.

    Raises:
        RuntimeError: If GetModuleFileNameW returns 0.
    """
    for capacity in (260, 32768):  # MAX_PATH, then extended path length
        path_buf = ctypes.create_unicode_buffer(capacity)
        n_chars = kernel32.GetModuleFileNameW(handle, path_buf, len(path_buf))
        if not n_chars:
            error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
            raise RuntimeError(f"GetModuleFileNameW failed for {libname!r} (error code: {error_code})")
        if n_chars != len(path_buf):
            # The path fit without filling the buffer; no retry needed.
            break
    return path_buf.value
def check_if_already_loaded_from_elsewhere(desc: LibDescriptor, have_abs_path: bool) -> LoadedDL | None:
    """Report a DLL of *desc* that is already loaded in this process, if any.

    Each name in ``desc.windows_dlls`` is queried, in order, with
    GetModuleHandleW. For the first name that yields a handle, the module's
    path is resolved and a LoadedDL labelled
    ``"was-already-loaded-from-elsewhere"`` is returned.

    Args:
        desc: Descriptor of the library to look for.
        have_abs_path: Whether the caller already has an absolute path for the
            library; together with ``desc.requires_add_dll_directory`` this
            decides whether add_dll_directory() is applied to the found path.

    Returns:
        A LoadedDL for the first already-loaded DLL, or None if none is loaded.
    """
    for candidate in desc.windows_dlls:
        module = kernel32.GetModuleHandleW(candidate)
        if not module:
            continue
        resolved = abs_path_for_dynamic_library(desc.name, module)
        if have_abs_path and desc.requires_add_dll_directory:
            # load_with_abs_path() calls add_dll_directory() as a side effect.
            # Apply the same side effect here, so the outcome does not depend
            # on who loaded the library first.
            add_dll_directory(resolved)
        return LoadedDL(resolved, True, ctypes_handle_to_unsigned_int(module), "was-already-loaded-from-elsewhere")
    return None
def load_with_system_search(desc: LibDescriptor) -> LoadedDL | None:
    """Try to load a DLL using the system search paths.

    The names in ``desc.windows_dlls`` are tried in reverse order, so newer
    names are tried before older ones. Each is passed to LoadLibraryExW with
    no flags. The first name that loads wins.

    Args:
        desc: Descriptor of the library to load. Its ``windows_dlls`` are the
            candidate DLL names; its ``name`` is passed along for error
            reporting when resolving the loaded module's path.

    Returns:
        A LoadedDL object (labelled ``"system-search"``) for the first DLL name
        that loads successfully, or None if none of the names can be loaded.
    """
    # Reverse tabulated names to achieve new -> old search order.
    for dll_name in reversed(desc.windows_dlls):
        # dwFlags=0: no LOAD_LIBRARY_SEARCH_* restriction, default search order.
        handle = kernel32.LoadLibraryExW(dll_name, None, 0)
        if handle:
            abs_path = abs_path_for_dynamic_library(desc.name, handle)
            return LoadedDL(abs_path, False, ctypes_handle_to_unsigned_int(handle), "system-search")

    return None
def load_with_abs_path(desc: LibDescriptor, found_path: str, found_via: str | None = None) -> LoadedDL:
    """Load the DLL located at *found_path*.

    If ``desc.requires_add_dll_directory`` is set, add_dll_directory() is
    applied to *found_path* first. The DLL is then loaded with LoadLibraryExW
    using the LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
    flags.

    Args:
        desc: Descriptor for the library to load.
        found_path: The absolute path to the DLL file.
        found_via: Label indicating how the path was discovered.

    Returns:
        A LoadedDL object representing the loaded library.

    Raises:
        RuntimeError: If the DLL cannot be loaded.
    """
    if desc.requires_add_dll_directory:
        add_dll_directory(found_path)

    handle = kernel32.LoadLibraryExW(
        found_path,
        None,
        WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR,
    )
    if handle:
        return LoadedDL(found_path, False, ctypes_handle_to_unsigned_int(handle), found_via)

    # Read the error immediately after the failed call, before anything else can overwrite it.
    error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
    raise RuntimeError(f"Failed to load DLL at {found_path}: Windows error {error_code}")