Coverage for cuda/core/checkpoint.py: 41.94%
124 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5import ctypes as _ctypes
6from collections.abc import Mapping
7from typing import Any
9from cuda.bindings import driver as _driver
10from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return
11from cuda.core._utils.version import binding_version as _binding_version
12from cuda.core._utils.version import driver_version as _driver_version
13from cuda.core.typing import ProcessStateType as _ProcessStateType
# Pairs of (attribute name looked up on ``driver.CUprocessState``, public
# state name reported by ``Process.state``).
_PROCESS_STATE_NAME_ATTRS: tuple[tuple[str, _ProcessStateType], ...] = (
    ("CU_PROCESS_STATE_RUNNING", "running"),
    ("CU_PROCESS_STATE_LOCKED", "locked"),
    ("CU_PROCESS_STATE_CHECKPOINTED", "checkpointed"),
    ("CU_PROCESS_STATE_FAILED", "failed"),
)
# Names that must exist on the ``cuda.bindings`` driver module before any
# checkpoint call is attempted. ``_get_driver`` reports the missing ones in
# this order, so the order is part of the error message.
_REQUIRED_BINDING_ATTRS = (
    # Checkpoint API entry points.
    "cuCheckpointProcessCheckpoint",
    "cuCheckpointProcessGetRestoreThreadId",
    "cuCheckpointProcessGetState",
    "cuCheckpointProcessLock",
    "cuCheckpointProcessRestore",
    "cuCheckpointProcessUnlock",
    # Supporting argument and enum types.
    "CUcheckpointGpuPair",
    "CUcheckpointLockArgs",
    "CUprocessState",
    "CUcheckpointRestoreArgs",
)

# Lowest driver version accepted by ``_get_driver`` (compared as a tuple;
# presumably (major, minor, patch) -- confirm against ``driver_version``).
_REQUIRED_DRIVER_VERSION = (12, 8, 0)

# Flipped to True by ``_get_driver`` once every capability check has passed,
# so later calls skip the checks.
_driver_capability_checked = False
class Process:
    """
    Handle for a CUDA process driven through the CUDA checkpoint API.

    An instance stores only the process ID. Each property and method obtains
    the driver bindings and issues the matching ``cuCheckpointProcess*`` call
    for that ID.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process. Validated on construction: it must be
        a positive ``int`` and ``bool`` values are rejected.
    """

    __slots__ = ("_pid",)

    def __init__(self, pid: int):
        # Validation raises TypeError / ValueError before anything is stored.
        self._pid = _check_pid(pid)

    @property
    def pid(self) -> int:
        """
        Process ID this handle was created with.
        """
        return self._pid

    @property
    def state(self) -> _ProcessStateType:
        """
        Current CUDA checkpoint state of the process, as a state name.

        Raises
        ------
        RuntimeError
            If the driver reports a state with no known name, or if
            checkpointing is unavailable.
        """
        drv = _get_driver()
        raw_state = _call_driver(drv, drv.cuCheckpointProcessGetState, self._pid)
        names_by_state = _get_process_state_names(drv)
        try:
            name = names_by_state[raw_state]
        except KeyError as e:
            raise RuntimeError(f"Unknown CUDA checkpoint process state: {int(raw_state)}") from e
        return name

    @property
    def restore_thread_id(self) -> int:
        """
        CUDA restore thread ID reported by the driver for this process.
        """
        drv = _get_driver()
        thread_id = _call_driver(drv, drv.cuCheckpointProcessGetRestoreThreadId, self._pid)
        return int(thread_id)

    def lock(self, timeout_ms: int = 0) -> None:
        """
        Lock the process, blocking further CUDA API calls.

        Parameters
        ----------
        timeout_ms : int, optional
            Timeout in milliseconds. A value of 0 indicates no timeout.
            Must be a non-negative ``int``.
        """
        drv = _get_driver()
        lock_args = drv.CUcheckpointLockArgs()
        lock_args.timeoutMs = _check_timeout_ms(timeout_ms)
        _call_driver(drv, drv.cuCheckpointProcessLock, self._pid, lock_args)

    def checkpoint(self) -> None:
        """
        Checkpoint the GPU memory contents of the locked process.
        """
        drv = _get_driver()
        # No argument structure is supplied for this call.
        _call_driver(drv, drv.cuCheckpointProcessCheckpoint, self._pid, None)

    def restore(self, gpu_mapping: Mapping[Any, Any] | None = None) -> None:
        """
        Restore the checkpointed process.

        Parameters
        ----------
        gpu_mapping : mapping, optional
            GPU UUID remapping from each checkpointed GPU UUID to the GPU UUID
            to restore onto. For migration workflows, provide mappings for
            every GPU visible to the kernel-mode driver. User-space masking
            such as ``CUDA_VISIBLE_DEVICES`` does not reduce this mapping
            requirement. ``None`` or an empty mapping sends no restore
            arguments.
        """
        drv = _get_driver()
        restore_args = _make_restore_args(drv, gpu_mapping)
        _call_driver(drv, drv.cuCheckpointProcessRestore, self._pid, restore_args)

    def unlock(self) -> None:
        """
        Unlock the locked process so it can resume CUDA API calls.
        """
        drv = _get_driver()
        # No argument structure is supplied for this call.
        _call_driver(drv, drv.cuCheckpointProcessUnlock, self._pid, None)
def _get_driver() -> Any:
    """Return the driver bindings module after checking checkpoint support.

    Three checks run, in order, until they have all passed once:
    the ``cuda.bindings`` version, the presence of every name in
    ``_REQUIRED_BINDING_ATTRS`` on the driver module, and the installed
    driver version against ``_REQUIRED_DRIVER_VERSION``. After a fully
    successful pass the module-level flag is set and later calls return
    immediately.

    Raises
    ------
    RuntimeError
        If any of the checks fails. The flag stays unset in that case, so the
        checks run again on the next call.
    """
    global _driver_capability_checked

    if not _driver_capability_checked:
        bindings_version = _binding_version()
        if not _binding_version_supports_checkpoint(bindings_version):
            version_text = ".".join(str(part) for part in bindings_version[:3])
            raise RuntimeError(
                "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. "
                f"Found cuda.bindings {version_text}."
            )

        absent = [name for name in _REQUIRED_BINDING_ATTRS if not hasattr(_driver, name)]
        if absent:
            absent_text = ", ".join(absent)
            raise RuntimeError(
                f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {absent_text}"
            )

        if _driver_version() < _REQUIRED_DRIVER_VERSION:
            raise RuntimeError(
                "CUDA checkpointing is not supported by the installed NVIDIA driver. "
                "Upgrade to a driver version with CUDA checkpoint API support."
            )

        # Only reached when every check above passed.
        _driver_capability_checked = True

    return _driver
def _binding_version_supports_checkpoint(version: tuple[int, ...]) -> bool:
    """Return whether a ``cuda.bindings`` version is accepted for checkpointing.

    Accepted versions are 12.8.0 or later within the 12.x series, 13.0.2 or
    later within the 13.x series, and any major version above 13.

    Parameters
    ----------
    version : tuple of int
        Version components, most significant first. Only the first three are
        examined. Missing trailing components are treated as 0, so ``(12, 9)``
        is read as ``(12, 9, 0)`` instead of failing to unpack.

    Returns
    -------
    bool
        ``True`` when the version is in one of the accepted ranges.
    """
    head = tuple(version[:3])
    # Pad short versions so the three-way unpack cannot raise ValueError.
    major, minor, patch = head + (0,) * (3 - len(head))

    if major == 12:
        return (minor, patch) >= (8, 0)
    if major == 13:
        return (minor, patch) >= (0, 2)
    return major > 13
def _get_process_state_names(driver: Any) -> dict[Any, _ProcessStateType]:
    """Build a lookup from ``driver.CUprocessState`` members to state names.

    Each attribute listed in ``_PROCESS_STATE_NAME_ATTRS`` is read from
    ``driver.CUprocessState`` and used as the key for its public state name.
    """
    state_enum = driver.CUprocessState
    names_by_state: dict[Any, _ProcessStateType] = {}
    for attr_name, public_name in _PROCESS_STATE_NAME_ATTRS:
        names_by_state[getattr(state_enum, attr_name)] = public_name
    return names_by_state
def _call_driver(driver: Any, func: Any, *args: Any) -> Any:
    """Invoke a driver binding and translate "unsupported" failures.

    Parameters
    ----------
    driver : module-like
        Object exposing ``CUresult``; used to look up the error codes that
        mean the checkpoint API is unavailable.
    func : callable
        Binding to call. Its result is indexed, with element 0 taken as the
        status code.
    *args
        Positional arguments forwarded to ``func``.

    Returns
    -------
    Any
        Whatever ``handle_return`` produces for the binding's result.

    Raises
    ------
    RuntimeError
        With an "upgrade the driver" message when the binding raises a
        ``RuntimeError`` mentioning both ``cuCheckpointProcess`` and
        ``not found``, or when the status equals ``CUDA_ERROR_NOT_FOUND`` or
        ``CUDA_ERROR_NOT_SUPPORTED``. Any other ``RuntimeError`` from the
        binding propagates unchanged.
    """
    unsupported_message = (
        "CUDA checkpointing is not supported by the installed NVIDIA driver. "
        "Upgrade to a driver version with CUDA checkpoint API support."
    )

    try:
        outcome = func(*args)
    except RuntimeError as e:
        text = str(e)
        if "cuCheckpointProcess" not in text or "not found" not in text:
            raise
        raise RuntimeError(unsupported_message) from e

    status = outcome[0]
    # A code missing from this bindings build resolves to None and simply
    # never matches a real status.
    unsupported_codes = tuple(
        getattr(driver.CUresult, code_name, None)
        for code_name in ("CUDA_ERROR_NOT_FOUND", "CUDA_ERROR_NOT_SUPPORTED")
    )
    if status in unsupported_codes:
        raise RuntimeError(unsupported_message)

    return _handle_cuda_return(outcome)
def _check_pid(pid: int) -> int:
    """Validate a process ID and return it unchanged.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int``. ``bool`` is rejected even though it
        subclasses ``int``.
    ValueError
        If ``pid`` is zero or negative.
    """
    is_plain_int = isinstance(pid, int) and not isinstance(pid, bool)
    if not is_plain_int:
        raise TypeError("pid must be an int")
    if pid > 0:
        return pid
    raise ValueError("pid must be a positive int")
def _check_timeout_ms(timeout_ms: int) -> int:
    """Validate a millisecond timeout and return it unchanged.

    Raises
    ------
    TypeError
        If ``timeout_ms`` is not an ``int``. ``bool`` is rejected even though
        it subclasses ``int``.
    ValueError
        If ``timeout_ms`` is negative.
    """
    is_plain_int = isinstance(timeout_ms, int) and not isinstance(timeout_ms, bool)
    if not is_plain_int:
        raise TypeError("timeout_ms must be an int")
    if timeout_ms >= 0:
        return timeout_ms
    raise ValueError("timeout_ms must be >= 0")
207def _make_restore_args(driver: Any, gpu_mapping: Mapping[Any, Any] | None) -> Any:
208 if gpu_mapping is None:
209 return None
210 if not isinstance(gpu_mapping, Mapping):
211 raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID")
213 pairs = []
214 for old_uuid, new_uuid in gpu_mapping.items():
215 pair = driver.CUcheckpointGpuPair()
216 buffers: list[Any] = [] # holds ctypes string-buffer keepalives for the call below
217 pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers)
218 pair.newUuid = _as_cuuuid(driver, new_uuid, buffers)
219 pairs.append(pair)
221 if not pairs:
222 return None
224 args = driver.CUcheckpointRestoreArgs()
225 args.gpuPairs = pairs
226 args.gpuPairsCount = len(pairs)
227 return args
def _as_cuuuid(driver: Any, value: Any, buffers: list[Any]) -> Any:
    """Convert *value* to a ``CUuuid``.

    Accepts a ``CUuuid`` instance (returned as-is) or a UUID string in
    the ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` format returned by
    :attr:`Device.uuid`.

    For a string, hyphens are removed and the rest is decoded as hex into
    exactly 16 bytes. Those bytes are placed in a ctypes buffer that is
    appended to *buffers*, and the ``CUuuid`` is constructed from that
    buffer's address.

    Raises
    ------
    ValueError
        If the string is not valid hex or does not decode to 16 bytes.
    TypeError
        If *value* is neither a ``CUuuid`` nor a ``str``.
    """
    if isinstance(value, driver.CUuuid):
        return value
    if not isinstance(value, str):
        raise TypeError("GPU UUID values must be CUDA UUID objects or UUID strings")

    def _invalid() -> ValueError:
        # Single source for the message shared by both failure modes.
        return ValueError(f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}")

    hex_digits = value.replace("-", "")
    try:
        uuid_bytes = bytes.fromhex(hex_digits)
    except ValueError:
        raise _invalid() from None
    if len(uuid_bytes) != 16:
        raise _invalid()

    backing = _ctypes.create_string_buffer(uuid_bytes, 16)
    # The caller's list holds the reference that keeps the buffer alive; the
    # CUuuid below is built from its raw address only.
    buffers.append(backing)
    return driver.CUuuid(_ctypes.addressof(backing))
# Public API of this module: only the ``Process`` class is exported.
__all__ = ["Process"]