Coverage for cuda / pathfinder / _dynamic_libs / load_dl_windows.py: 78.57%

70 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-08 01:07 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# SPDX-License-Identifier: Apache-2.0 

3 

4from __future__ import annotations 

5 

6import ctypes 

7import ctypes.wintypes 

8import os 

9import struct 

10from typing import TYPE_CHECKING 

11 

12from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL 

13 

14if TYPE_CHECKING: 

15 from cuda.pathfinder._dynamic_libs.lib_descriptor import LibDescriptor 

16 

# LoadLibraryExW flag values copied from the Windows SDK header WinBase.h;
# they are not defined elsewhere (e.g. not in ctypes.wintypes), so they are
# spelled out here.
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x100
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x1000

# Count of distinct values a native pointer can hold (2**64 on a 64-bit
# interpreter, 2**32 on 32-bit); used to reinterpret a negative handle value
# as its unsigned equivalent.
POINTER_ADDRESS_SPACE = 1 << (8 * struct.calcsize("P"))

22 

# Set up kernel32 functions with proper types.
# Declaring argtypes/restype makes ctypes marshal arguments and return values
# with the correct (pointer-sized) widths instead of its default int-based
# conversions.
# NOTE(review): ctypes.windll caches the loaded library and its function
# objects, so these assignments presumably also apply to any other code in the
# process that uses ctypes.windll.kernel32 — confirm that sharing is intended.
kernel32 = ctypes.windll.kernel32  # type: ignore[attr-defined]

# GetModuleHandleW(lpModuleName) -> HMODULE
# HMODULE is c_void_p-based, so ctypes hands back a plain int (or None for a
# NULL handle) rather than a ctypes object.
kernel32.GetModuleHandleW.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.GetModuleHandleW.restype = ctypes.wintypes.HMODULE

# LoadLibraryExW(lpLibFileName, hFile, dwFlags) -> HMODULE (NULL on failure)
kernel32.LoadLibraryExW.argtypes = [
    ctypes.wintypes.LPCWSTR,  # lpLibFileName
    ctypes.wintypes.HANDLE,  # hFile (reserved, must be NULL)
    ctypes.wintypes.DWORD,  # dwFlags
]
kernel32.LoadLibraryExW.restype = ctypes.wintypes.HMODULE

# GetModuleFileNameW(hModule, lpFilename, nSize) -> number of characters
# written to lpFilename (0 on failure)
kernel32.GetModuleFileNameW.argtypes = [
    ctypes.wintypes.HMODULE,  # hModule
    ctypes.wintypes.LPWSTR,  # lpFilename (output buffer)
    ctypes.wintypes.DWORD,  # nSize (buffer capacity in characters)
]
kernel32.GetModuleFileNameW.restype = ctypes.wintypes.DWORD

# AddDllDirectory (Windows 7+)
kernel32.AddDllDirectory.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.AddDllDirectory.restype = ctypes.c_void_p  # DLL_DIRECTORY_COOKIE (NULL on failure)

49 

50 

def ctypes_handle_to_unsigned_int(handle: ctypes.wintypes.HMODULE | int) -> int:
    """Return a module handle as a non-negative Python int.

    Args:
        handle: The handle, either as a plain ``int`` (what ctypes returns for a
            call whose restype is ``HMODULE``/``c_void_p``) or as a
            ``ctypes.c_void_p`` instance such as ``ctypes.wintypes.HMODULE(...)``.

    Returns:
        The unsigned value of the handle. A negative (signed) pointer-sized
        value is mapped to its unsigned equivalent by adding
        ``POINTER_ADDRESS_SPACE``.

    Raises:
        TypeError: If ``handle`` is ``None`` (or a NULL ``c_void_p``) or is
            otherwise not convertible with ``int()``.
    """
    if isinstance(handle, ctypes.c_void_p):
        # int() does not accept ctypes pointer instances; unwrap to the raw
        # value (None for a NULL pointer, which int() rejects as before).
        handle = handle.value
    handle_uint = int(handle)
    if handle_uint < 0:
        # Convert from signed to unsigned representation
        handle_uint += POINTER_ADDRESS_SPACE
    return handle_uint

58 

59 

def add_dll_directory(dll_abs_path: str) -> None:
    """Add a DLL's directory to the DLL search path and to the PATH environment variable.

    The directory is registered with ``AddDllDirectory`` and also appended to
    ``PATH`` (unless it is already listed there) as a fallback for resolving
    the DLL's dependents.

    Args:
        dll_abs_path: Absolute path to the DLL file

    Raises:
        AssertionError: If the directory containing the DLL does not exist
    """
    dirpath = os.path.dirname(dll_abs_path)
    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``; the exception type stays AssertionError for callers.
    if not os.path.isdir(dirpath):
        raise AssertionError(dll_abs_path)

    # The returned DLL_DIRECTORY_COOKIE is deliberately not kept, so this code
    # never removes the directory again. A NULL result (failure) is tolerated
    # because PATH is updated below regardless.
    kernel32.AddDllDirectory(dirpath)

    # Update PATH as a fallback for dependent DLL resolution.
    curr_path = os.environ.get("PATH")
    if not curr_path:
        # Covers both an unset and an empty PATH (avoids a leading empty entry).
        os.environ["PATH"] = dirpath
        return

    # Skip directories that are already listed so repeated calls do not grow
    # PATH without bound; a duplicate later entry would never be consulted.
    wanted = os.path.normcase(os.path.normpath(dirpath))
    already_listed = any(
        os.path.normcase(os.path.normpath(entry)) == wanted for entry in curr_path.split(os.pathsep) if entry
    )
    if not already_listed:
        os.environ["PATH"] = os.pathsep.join((curr_path, dirpath))

81 

82 

def abs_path_for_dynamic_library(libname: str, handle: ctypes.wintypes.HMODULE) -> str:
    """Return the file path of the loaded module identified by ``handle``.

    ``GetModuleFileNameW`` is queried with a MAX_PATH-sized buffer first. If
    the reported length fills that buffer (so the name may be truncated), the
    query is repeated once with a 32768-character buffer.

    Args:
        libname: Library name; used only in error messages.
        handle: Handle of the loaded module.

    Returns:
        The path string reported by ``GetModuleFileNameW``.

    Raises:
        RuntimeError: If ``GetModuleFileNameW`` returns 0 (failure).
    """
    for capacity in (260, 32768):  # MAX_PATH, then extended path length
        path_buf = ctypes.create_unicode_buffer(capacity)
        n_chars = kernel32.GetModuleFileNameW(handle, path_buf, len(path_buf))
        if n_chars == 0:
            error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
            raise RuntimeError(f"GetModuleFileNameW failed for {libname!r} (error code: {error_code})")
        if n_chars != len(path_buf):
            # The name did not fill the buffer, so it was not truncated.
            break
    return path_buf.value

102 

103 

def check_if_already_loaded_from_elsewhere(desc: LibDescriptor, have_abs_path: bool) -> LoadedDL | None:
    """Return the library if one of its DLL names already has a module handle.

    Each name in ``desc.windows_dlls`` is probed in order with
    ``GetModuleHandleW``; the first name that yields a handle wins.

    Args:
        desc: Descriptor of the library to look for.
        have_abs_path: When true and ``desc.requires_add_dll_directory`` is set,
            ``add_dll_directory`` is applied to the found path, mirroring what
            ``load_with_abs_path`` does.

    Returns:
        A ``LoadedDL`` tagged ``"was-already-loaded-from-elsewhere"``, or
        ``None`` if no name in ``desc.windows_dlls`` has a module handle.
    """
    for candidate in desc.windows_dlls:
        module_handle = kernel32.GetModuleHandleW(candidate)
        if not module_handle:
            continue
        loaded_path = abs_path_for_dynamic_library(desc.name, module_handle)
        # load_with_abs_path() calls add_dll_directory() as a side effect; apply
        # it here too so the outcome does not depend on who loaded the DLL first.
        if have_abs_path and desc.requires_add_dll_directory:
            add_dll_directory(loaded_path)
        return LoadedDL(
            loaded_path,
            True,
            ctypes_handle_to_unsigned_int(module_handle),
            "was-already-loaded-from-elsewhere",
        )
    return None

116 

117 

def load_with_system_search(desc: LibDescriptor) -> LoadedDL | None:
    """Try to load a DLL by name through the system DLL search path.

    Each candidate name is passed to ``LoadLibraryExW`` with ``dwFlags=0``;
    the first one that loads wins.

    Args:
        desc: Descriptor of the library to load. ``desc.windows_dlls`` lists the
            candidate DLL file names and ``desc.name`` is used in error messages.

    Returns:
        A LoadedDL tagged ``"system-search"`` if successful, None if the library
        cannot be loaded under any of its names.

    Raises:
        RuntimeError: If a DLL loads but its file path cannot be retrieved.
    """
    # Reverse tabulated names to achieve new -> old search order.
    for dll_name in reversed(desc.windows_dlls):
        handle = kernel32.LoadLibraryExW(dll_name, None, 0)
        if handle:
            abs_path = abs_path_for_dynamic_library(desc.name, handle)
            return LoadedDL(abs_path, False, ctypes_handle_to_unsigned_int(handle), "system-search")

    return None

135 

136 

def load_with_abs_path(desc: LibDescriptor, found_path: str, found_via: str | None = None) -> LoadedDL:
    """Load the DLL located at ``found_path``.

    Args:
        desc: Descriptor for the library to load; if
            ``desc.requires_add_dll_directory`` is set, the DLL's directory is
            registered via ``add_dll_directory`` before loading.
        found_path: The absolute path to the DLL file.
        found_via: Label indicating how the path was discovered; stored
            unchanged on the result.

    Returns:
        A LoadedDL for ``found_path``.

    Raises:
        RuntimeError: If ``LoadLibraryExW`` returns NULL for ``found_path``.
    """
    if desc.requires_add_dll_directory:
        add_dll_directory(found_path)

    # Per the LoadLibraryExW documentation, these flags make the loader resolve
    # the DLL's dependencies from its own directory plus the default directories
    # (which include directories registered with AddDllDirectory).
    search_flags = WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
    module_handle = kernel32.LoadLibraryExW(found_path, None, search_flags)
    if module_handle:
        return LoadedDL(found_path, False, ctypes_handle_to_unsigned_int(module_handle), found_via)

    error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
    raise RuntimeError(f"Failed to load DLL at {found_path}: Windows error {error_code}")