Coverage for cuda/core/checkpoint.py: 41.94%

124 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-13 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5import ctypes as _ctypes 

6from collections.abc import Mapping 

7from typing import Any 

8 

9from cuda.bindings import driver as _driver 

10from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return 

11from cuda.core._utils.version import binding_version as _binding_version 

12from cuda.core._utils.version import driver_version as _driver_version 

13from cuda.core.typing import ProcessStateType as _ProcessStateType 

14 

# (CUprocessState member name, public state string) pairs. The member names are
# resolved against the driver module by _get_process_state_names().
_PROCESS_STATE_NAME_ATTRS: tuple[tuple[str, _ProcessStateType], ...] = (
    ("CU_PROCESS_STATE_RUNNING", "running"),
    ("CU_PROCESS_STATE_LOCKED", "locked"),
    ("CU_PROCESS_STATE_CHECKPOINTED", "checkpointed"),
    ("CU_PROCESS_STATE_FAILED", "failed"),
)

# Attributes the cuda.bindings driver module must expose. _get_driver() lists
# any that are missing, in this order, in its error message.
_REQUIRED_BINDING_ATTRS = (
    "cuCheckpointProcessCheckpoint",
    "cuCheckpointProcessGetRestoreThreadId",
    "cuCheckpointProcessGetState",
    "cuCheckpointProcessLock",
    "cuCheckpointProcessRestore",
    "cuCheckpointProcessUnlock",
    "CUcheckpointGpuPair",
    "CUcheckpointLockArgs",
    "CUprocessState",
    "CUcheckpointRestoreArgs",
)
# Lowest driver version accepted by _get_driver(); compared as a tuple.
_REQUIRED_DRIVER_VERSION = (12, 8, 0)
# Flipped to True by _get_driver() once all of its checks have passed, so that
# later calls skip the checks.
_driver_capability_checked = False

36 

37 

class Process:
    """
    CUDA process that can be locked, checkpointed, restored, and unlocked.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int``; ``bool`` values are rejected too.
    ValueError
        If ``pid`` is not positive.
    """

    __slots__ = ("_pid",)

    def __init__(self, pid: int):
        self._pid = _check_pid(pid)

    def __repr__(self) -> str:
        """Return a debugging representation such as ``Process(pid=1234)``."""
        return f"{type(self).__name__}(pid={self._pid!r})"

    @property
    def pid(self) -> int:
        """
        Process ID of the CUDA process.
        """
        return self._pid

    @property
    def state(self) -> _ProcessStateType:
        """
        CUDA checkpoint state for this process.

        Raises
        ------
        RuntimeError
            If the driver reports a state that has no known public name, or
            if checkpointing is unavailable.
        """
        driver = _get_driver()
        state = _call_driver(driver, driver.cuCheckpointProcessGetState, self._pid)
        state_names = _get_process_state_names(driver)
        try:
            return state_names[state]
        except KeyError as e:
            # Report the raw numeric value so an unrecognized state is identifiable.
            state_value = int(state)
            raise RuntimeError(f"Unknown CUDA checkpoint process state: {state_value}") from e

    @property
    def restore_thread_id(self) -> int:
        """
        CUDA restore thread ID for this process.
        """
        driver = _get_driver()
        return int(_call_driver(driver, driver.cuCheckpointProcessGetRestoreThreadId, self._pid))

    def lock(self, timeout_ms: int = 0) -> None:
        """
        Lock this process, blocking further CUDA API calls.

        Parameters
        ----------
        timeout_ms : int, optional
            Timeout in milliseconds. A value of 0 indicates no timeout.

        Raises
        ------
        TypeError
            If ``timeout_ms`` is not an ``int``; ``bool`` values are rejected too.
        ValueError
            If ``timeout_ms`` is negative.
        """
        driver = _get_driver()
        args = driver.CUcheckpointLockArgs()
        args.timeoutMs = _check_timeout_ms(timeout_ms)
        _call_driver(driver, driver.cuCheckpointProcessLock, self._pid, args)

    def checkpoint(self) -> None:
        """
        Checkpoint the GPU memory contents of this locked process.
        """
        driver = _get_driver()
        # No argument struct is passed for this call.
        _call_driver(driver, driver.cuCheckpointProcessCheckpoint, self._pid, None)

    def restore(self, gpu_mapping: Mapping[Any, Any] | None = None) -> None:
        """
        Restore this checkpointed process.

        Parameters
        ----------
        gpu_mapping : mapping, optional
            GPU UUID remapping from each checkpointed GPU UUID to the GPU UUID
            to restore onto. For migration workflows, provide mappings for
            every GPU visible to the kernel-mode driver. User-space masking
            such as ``CUDA_VISIBLE_DEVICES`` does not reduce this mapping
            requirement. An empty mapping is handled like ``None``.

        Raises
        ------
        TypeError
            If ``gpu_mapping`` is not a mapping, or a UUID in it is neither a
            CUDA UUID object nor a string.
        ValueError
            If a UUID string is not 32 hex characters (hyphens optional).
        """
        driver = _get_driver()
        args = _make_restore_args(driver, gpu_mapping)
        _call_driver(driver, driver.cuCheckpointProcessRestore, self._pid, args)

    def unlock(self) -> None:
        """
        Unlock this locked process so it can resume CUDA API calls.
        """
        driver = _get_driver()
        # No argument struct is passed for this call.
        _call_driver(driver, driver.cuCheckpointProcessUnlock, self._pid, None)

126 

127 

def _get_driver() -> Any:
    """Return the driver bindings module once checkpoint support is confirmed.

    Three checks run in order: the ``cuda.bindings`` version, the presence of
    every name in ``_REQUIRED_BINDING_ATTRS`` on the driver module, and the
    installed driver version against ``_REQUIRED_DRIVER_VERSION``. After they
    all pass, a module-level flag is set and later calls return immediately.
    A failed check does not set the flag, so the checks rerun on the next call.

    Raises
    ------
    RuntimeError
        If any of the three checks fails.
    """
    global _driver_capability_checked
    if not _driver_capability_checked:
        bindings_ver = _binding_version()
        if not _binding_version_supports_checkpoint(bindings_ver):
            raise RuntimeError(
                "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. "
                f"Found cuda.bindings {'.'.join(map(str, bindings_ver[:3]))}."
            )

        absent = [attr for attr in _REQUIRED_BINDING_ATTRS if not hasattr(_driver, attr)]
        if absent:
            raise RuntimeError(
                f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(absent)}"
            )

        if _driver_version() < _REQUIRED_DRIVER_VERSION:
            raise RuntimeError(
                "CUDA checkpointing is not supported by the installed NVIDIA driver. "
                "Upgrade to a driver version with CUDA checkpoint API support."
            )

        _driver_capability_checked = True

    return _driver

155 

156 

def _binding_version_supports_checkpoint(version: tuple[int, ...]) -> bool:
    """Return whether a ``cuda.bindings`` version is new enough for checkpointing.

    Accepted versions are 12.x from 12.8.0 onward, 13.x from 13.0.2 onward,
    and any major version above 13.

    Parameters
    ----------
    version : tuple of int
        Version components, major first. Only the first three are examined;
        missing minor or patch components are treated as 0.
    """
    # Pad with zeros so a short tuple such as (13,) unpacks instead of raising.
    major, minor, patch = (*version[:3], 0, 0, 0)[:3]
    if major == 12:
        return (minor, patch) >= (8, 0)
    if major == 13:
        return (minor, patch) >= (0, 2)
    return major > 13

160 

161 

def _get_process_state_names(driver: Any) -> dict[Any, _ProcessStateType]:
    """Build a lookup from ``driver.CUprocessState`` members to public state strings.

    The member names and their strings come from ``_PROCESS_STATE_NAME_ATTRS``.
    """
    state_enum = driver.CUprocessState
    names = {}
    for member_name, label in _PROCESS_STATE_NAME_ATTRS:
        names[getattr(state_enum, member_name)] = label
    return names

164 

165 

def _call_driver(driver: Any, func: Any, *args: Any) -> Any:
    """Call a driver function and turn "unsupported" outcomes into one clear error.

    Parameters
    ----------
    driver : module
        Driver bindings module; its ``CUresult`` supplies the error codes.
    func : callable
        Driver function to call with ``*args``.

    Returns
    -------
    Whatever ``_handle_cuda_return`` produces for the call's result.

    Raises
    ------
    RuntimeError
        With an "upgrade your driver" message when the call raises a
        ``RuntimeError`` mentioning both ``cuCheckpointProcess`` and
        ``not found``, or when the returned status is ``CUDA_ERROR_NOT_FOUND``
        or ``CUDA_ERROR_NOT_SUPPORTED``. Any other ``RuntimeError`` from the
        call propagates unchanged.
    """
    try:
        result = func(*args)
    except RuntimeError as exc:
        text = str(exc)
        if "cuCheckpointProcess" not in text or "not found" not in text:
            raise
        raise RuntimeError(
            "CUDA checkpointing is not supported by the installed NVIDIA driver. "
            "Upgrade to a driver version with CUDA checkpoint API support."
        ) from exc

    status = result[0]
    result_codes = driver.CUresult
    # getattr with a None default tolerates bindings that lack either code.
    unsupported = (
        getattr(result_codes, "CUDA_ERROR_NOT_FOUND", None),
        getattr(result_codes, "CUDA_ERROR_NOT_SUPPORTED", None),
    )
    if status in unsupported:
        raise RuntimeError(
            "CUDA checkpointing is not supported by the installed NVIDIA driver. "
            "Upgrade to a driver version with CUDA checkpoint API support."
        )

    return _handle_cuda_return(result)

189 

190 

def _check_pid(pid: int) -> int:
    """Validate a process ID and return it unchanged.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int``; ``bool`` is rejected even though it
        subclasses ``int``.
    ValueError
        If ``pid`` is zero or negative.
    """
    is_plain_int = isinstance(pid, int) and not isinstance(pid, bool)
    if not is_plain_int:
        raise TypeError("pid must be an int")
    if pid <= 0:
        raise ValueError("pid must be a positive int")
    return pid

197 

198 

def _check_timeout_ms(timeout_ms: int) -> int:
    """Validate a millisecond timeout and return it unchanged.

    Raises
    ------
    TypeError
        If ``timeout_ms`` is not an ``int``; ``bool`` is rejected even though
        it subclasses ``int``.
    ValueError
        If ``timeout_ms`` is negative.
    """
    is_plain_int = isinstance(timeout_ms, int) and not isinstance(timeout_ms, bool)
    if not is_plain_int:
        raise TypeError("timeout_ms must be an int")
    if timeout_ms < 0:
        raise ValueError("timeout_ms must be >= 0")
    return timeout_ms

205 

206 

207def _make_restore_args(driver: Any, gpu_mapping: Mapping[Any, Any] | None) -> Any: 

208 if gpu_mapping is None: 

209 return None 

210 if not isinstance(gpu_mapping, Mapping): 

211 raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID") 

212 

213 pairs = [] 

214 for old_uuid, new_uuid in gpu_mapping.items(): 

215 pair = driver.CUcheckpointGpuPair() 

216 buffers: list[Any] = [] # holds ctypes string-buffer keepalives for the call below 

217 pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers) 

218 pair.newUuid = _as_cuuuid(driver, new_uuid, buffers) 

219 pairs.append(pair) 

220 

221 if not pairs: 

222 return None 

223 

224 args = driver.CUcheckpointRestoreArgs() 

225 args.gpuPairs = pairs 

226 args.gpuPairsCount = len(pairs) 

227 return args 

228 

229 

def _as_cuuuid(driver: Any, value: Any, buffers: list[Any]) -> Any:
    """Convert *value* to a ``CUuuid``.

    A ``CUuuid`` instance is returned unchanged. A string is parsed as 32 hex
    characters, with hyphens ignored, as in the
    ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` form returned by
    :attr:`Device.uuid`.

    Parameters
    ----------
    driver : module
        Driver bindings module supplying ``CUuuid``.
    value : CUuuid or str
        UUID to convert.
    buffers : list
        Receives the ctypes buffer backing a converted string. The returned
        ``CUuuid`` is built from that buffer's address, so the caller keeps
        this list alive for as long as it needs the buffer.

    Raises
    ------
    TypeError
        If *value* is neither a ``CUuuid`` nor a string.
    ValueError
        If the string is not valid hex or does not decode to 16 bytes.
    """
    if isinstance(value, driver.CUuuid):
        return value
    if not isinstance(value, str):
        raise TypeError("GPU UUID values must be CUDA UUID objects or UUID strings")

    hex_digits = value.replace("-", "")
    try:
        uuid_bytes = bytes.fromhex(hex_digits)
    except ValueError:
        raise ValueError(
            f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}"
        ) from None
    if len(uuid_bytes) != 16:
        raise ValueError(f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}")

    # Exactly 16 bytes: passing the size explicitly leaves no room for a NUL terminator.
    storage = _ctypes.create_string_buffer(uuid_bytes, 16)
    buffers.append(storage)
    return driver.CUuuid(_ctypes.addressof(storage))

252 

253 

# Public API of this module: only the Process class is exported.
__all__ = [
    "Process",
]