Coverage for cuda / core / checkpoint.py: 43.44%

122 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 01:37 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5import ctypes as _ctypes 

6from collections.abc import Mapping as _Mapping 

7from typing import Any as _Any 

8 

9from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return 

10from cuda.core._utils.version import binding_version as _binding_version 

11from cuda.core._utils.version import driver_version as _driver_version 

12from cuda.core.typing import ProcessStateType as _ProcessStateType 

13 

14try: 

15 from cuda.bindings import driver as _driver 

16except ImportError: 

17 from cuda import cuda as _driver 

18 

19 

# (CUprocessState member name, public state string) pairs. They are resolved
# against the driver bindings at lookup time by _get_process_state_names().
_PROCESS_STATE_NAME_ATTRS: tuple[tuple[str, _ProcessStateType], ...] = (
    ("CU_PROCESS_STATE_RUNNING", "running"),
    ("CU_PROCESS_STATE_LOCKED", "locked"),
    ("CU_PROCESS_STATE_CHECKPOINTED", "checkpointed"),
    ("CU_PROCESS_STATE_FAILED", "failed"),
)

# Attributes the driver bindings module must expose; _get_driver() reports any
# that are absent, in this order.
_REQUIRED_BINDING_ATTRS = (
    # Entry points.
    "cuCheckpointProcessCheckpoint",
    "cuCheckpointProcessGetRestoreThreadId",
    "cuCheckpointProcessGetState",
    "cuCheckpointProcessLock",
    "cuCheckpointProcessRestore",
    "cuCheckpointProcessUnlock",
    # Types.
    "CUcheckpointGpuPair",
    "CUcheckpointLockArgs",
    "CUprocessState",
    "CUcheckpointRestoreArgs",
)

# Lower bound that _get_driver() compares the reported driver version against.
_REQUIRED_DRIVER_VERSION = (12, 8, 0)

# Set to True by _get_driver() once every capability check has passed, so the
# checks are not repeated on later calls.
_driver_capability_checked = False

41 

42 

class Process:
    """
    Handle for a CUDA process that supports lock, checkpoint, restore, and unlock.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int`` (``bool`` is rejected as well).
    ValueError
        If ``pid`` is not positive.
    """

    __slots__ = ("_pid",)

    def __init__(self, pid: int):
        # Validation happens once here; every later driver call reuses the
        # stored value.
        self._pid = _check_pid(pid)

    @property
    def pid(self) -> int:
        """
        Process ID this object was constructed with.
        """
        return self._pid

    @property
    def state(self) -> _ProcessStateType:
        """
        Current CUDA checkpoint state of this process.

        Raises
        ------
        RuntimeError
            If the driver reports a state that has no known name.
        """
        drv = _get_driver()
        raw_state = _call_driver(drv, drv.cuCheckpointProcessGetState, self._pid)
        names_by_state = _get_process_state_names(drv)
        try:
            state_name = names_by_state[raw_state]
        except KeyError as e:
            # Report the numeric value so an unrecognized state can be identified.
            raise RuntimeError(f"Unknown CUDA checkpoint process state: {int(raw_state)}") from e
        return state_name

    @property
    def restore_thread_id(self) -> int:
        """
        ID of the CUDA restore thread of this process.
        """
        drv = _get_driver()
        query = drv.cuCheckpointProcessGetRestoreThreadId
        return _call_driver(drv, query, self._pid)

    def lock(self, timeout_ms: int = 0) -> None:
        """
        Lock this process so that further CUDA API calls are blocked.

        Parameters
        ----------
        timeout_ms : int, optional
            Timeout in milliseconds. A value of 0 indicates no timeout.
        """
        drv = _get_driver()
        lock_args = drv.CUcheckpointLockArgs()
        lock_args.timeoutMs = _check_timeout_ms(timeout_ms)
        _call_driver(drv, drv.cuCheckpointProcessLock, self._pid, lock_args)

    def checkpoint(self) -> None:
        """
        Checkpoint the GPU memory contents of this process, which must be locked.
        """
        drv = _get_driver()
        # No argument structure is supplied for this call.
        _call_driver(drv, drv.cuCheckpointProcessCheckpoint, self._pid, None)

    def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None:
        """
        Restore this process from its checkpointed state.

        Parameters
        ----------
        gpu_mapping : mapping, optional
            Maps each checkpointed GPU UUID to the GPU UUID it should be
            restored onto. Migration workflows must supply an entry for
            every GPU visible to the kernel-mode driver; user-space masking
            such as ``CUDA_VISIBLE_DEVICES`` does not shrink that
            requirement.
        """
        drv = _get_driver()
        restore_args = _make_restore_args(drv, gpu_mapping)
        _call_driver(drv, drv.cuCheckpointProcessRestore, self._pid, restore_args)

    def unlock(self) -> None:
        """
        Unlock this locked process, allowing it to resume CUDA API calls.
        """
        drv = _get_driver()
        # No argument structure is supplied for this call.
        _call_driver(drv, drv.cuCheckpointProcessUnlock, self._pid, None)

131 

132 

def _get_driver():
    """Return the driver bindings module after verifying checkpoint support.

    The first successful call checks the ``cuda.bindings`` version, the
    presence of every name in ``_REQUIRED_BINDING_ATTRS``, and the installed
    driver version. Success is remembered in ``_driver_capability_checked``
    so later calls return immediately; a failed check is not remembered and
    is re-evaluated on the next call.

    Raises
    ------
    RuntimeError
        If the bindings or the installed driver lack checkpoint support.
    """
    global _driver_capability_checked
    if _driver_capability_checked:
        return _driver

    # 1. The bindings release must be one that ships the checkpoint API.
    bindings_ver = _binding_version()
    if not _binding_version_supports_checkpoint(bindings_ver):
        found = ".".join(str(part) for part in bindings_ver[:3])
        raise RuntimeError(
            "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. "
            f"Found cuda.bindings {found}."
        )

    # 2. Every required function and type must actually be exposed.
    absent = []
    for attr in _REQUIRED_BINDING_ATTRS:
        if not hasattr(_driver, attr):
            absent.append(attr)
    if absent:
        raise RuntimeError(
            f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(absent)}"
        )

    # 3. The installed driver must be recent enough.
    if _driver_version() < _REQUIRED_DRIVER_VERSION:
        raise RuntimeError(
            "CUDA checkpointing is not supported by the installed NVIDIA driver. "
            "Upgrade to a driver version with CUDA checkpoint API support."
        )

    _driver_capability_checked = True
    return _driver

160 

161 

def _binding_version_supports_checkpoint(version) -> bool:
    """Return whether a ``cuda.bindings`` version provides the checkpoint API.

    Only the first three components of *version* are considered. Accepted
    versions are 12.x from 12.8.0, 13.x from 13.0.2, and any major version
    above 13.
    """
    major, minor, patch = version[:3]
    if major > 13:
        return True
    if major == 13:
        return (minor, patch) >= (0, 2)
    if major == 12:
        return (minor, patch) >= (8, 0)
    return False

165 

166 

def _get_process_state_names(driver) -> dict[_Any, _ProcessStateType]:
    """Map each ``driver.CUprocessState`` member to its public state string.

    The members are looked up by the names listed in
    ``_PROCESS_STATE_NAME_ATTRS``.
    """
    state_enum = driver.CUprocessState
    names_by_state = {}
    for member_name, label in _PROCESS_STATE_NAME_ATTRS:
        names_by_state[getattr(state_enum, member_name)] = label
    return names_by_state

169 

170 

def _call_driver(driver, func, *args):
    """Call a checkpoint driver function and translate "unsupported" failures.

    Parameters
    ----------
    driver
        Driver bindings module; its ``CUresult`` enum is consulted.
    func
        Callable to invoke with ``*args``. Its result is indexed, and the
        first element is treated as the status code.
    *args
        Positional arguments forwarded to *func*.

    Returns
    -------
    Whatever ``_handle_cuda_return`` produces for the call result.

    Raises
    ------
    RuntimeError
        With an upgrade hint when the checkpoint entry point is missing or
        the status is ``CUDA_ERROR_NOT_FOUND`` / ``CUDA_ERROR_NOT_SUPPORTED``.
        Any other ``RuntimeError`` raised by *func* propagates unchanged.
    """
    try:
        result = func(*args)
    except RuntimeError as e:
        text = str(e)
        # Presumably the bindings raise this when the driver library does not
        # export the symbol - TODO confirm the exact wording they use.
        if "cuCheckpointProcess" in text and "not found" in text:
            raise RuntimeError(
                "CUDA checkpointing is not supported by the installed NVIDIA driver. "
                "Upgrade to a driver version with CUDA checkpoint API support."
            ) from e
        raise

    status = result[0]
    result_enum = driver.CUresult
    # A code the enum does not define resolves to None and so never matches
    # a real status.
    unsupported_codes = tuple(
        getattr(result_enum, code_name, None)
        for code_name in ("CUDA_ERROR_NOT_FOUND", "CUDA_ERROR_NOT_SUPPORTED")
    )
    if status in unsupported_codes:
        raise RuntimeError(
            "CUDA checkpointing is not supported by the installed NVIDIA driver. "
            "Upgrade to a driver version with CUDA checkpoint API support."
        )

    return _handle_cuda_return(result)

194 

195 

def _check_pid(pid: int) -> int:
    """Validate *pid* and return it unchanged.

    Raises
    ------
    TypeError
        If *pid* is not an ``int``; ``bool`` is rejected too.
    ValueError
        If *pid* is zero or negative.
    """
    # bool is a subclass of int, so it needs an explicit rejection.
    is_plain_int = isinstance(pid, int) and not isinstance(pid, bool)
    if not is_plain_int:
        raise TypeError("pid must be an int")
    if pid <= 0:
        raise ValueError("pid must be a positive int")
    return pid

202 

203 

def _check_timeout_ms(timeout_ms: int) -> int:
    """Validate *timeout_ms* and return it unchanged.

    Raises
    ------
    TypeError
        If *timeout_ms* is not an ``int``; ``bool`` is rejected too.
    ValueError
        If *timeout_ms* is negative.
    """
    # bool is a subclass of int, so it needs an explicit rejection.
    is_plain_int = isinstance(timeout_ms, int) and not isinstance(timeout_ms, bool)
    if not is_plain_int:
        raise TypeError("timeout_ms must be an int")
    if timeout_ms < 0:
        raise ValueError("timeout_ms must be >= 0")
    return timeout_ms

210 

211 

212def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None): 

213 if gpu_mapping is None: 

214 return None 

215 if not isinstance(gpu_mapping, _Mapping): 

216 raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID") 

217 

218 pairs = [] 

219 for old_uuid, new_uuid in gpu_mapping.items(): 

220 pair = driver.CUcheckpointGpuPair() 

221 buffers = [] 

222 pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers) 

223 pair.newUuid = _as_cuuuid(driver, new_uuid, buffers) 

224 pairs.append(pair) 

225 

226 if not pairs: 

227 return None 

228 

229 args = driver.CUcheckpointRestoreArgs() 

230 args.gpuPairs = pairs 

231 args.gpuPairsCount = len(pairs) 

232 return args 

233 

234 

def _as_cuuuid(driver, value, buffers):
    """Convert *value* to a ``CUuuid``.

    Accepts a ``CUuuid`` instance (returned as-is) or a UUID string in
    the ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` format returned by
    :attr:`Device.uuid`.

    Parameters
    ----------
    driver
        Driver bindings module providing ``CUuuid``.
    value
        UUID string, or an already constructed ``CUuuid``. Any non-string
        value is returned unchanged without further validation.
    buffers : list
        Receives the ctypes buffer that backs a converted string, so the
        caller decides how long that memory stays alive.

    Raises
    ------
    ValueError
        If a string value does not decode to exactly 16 bytes of
        hexadecimal data (hyphens are ignored).
    """
    if not isinstance(value, str):
        return value

    try:
        raw = bytes.fromhex(value.replace("-", ""))
    except ValueError as e:
        # bytes.fromhex() reports non-hex characters with a generic message;
        # re-raise with the same descriptive text as the length check below.
        raise ValueError(
            f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}"
        ) from e
    if len(raw) != 16:
        raise ValueError(f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}")

    buf = _ctypes.create_string_buffer(raw, 16)
    # Presumably CUuuid(address) refers to this memory rather than copying it,
    # so the buffer is handed to the caller to keep alive - TODO confirm
    # against cuda.bindings.
    buffers.append(buf)
    return driver.CUuuid(_ctypes.addressof(buf))

250 

251 

# Public API of this module.
__all__ = ["Process"]