Coverage for cuda / core / checkpoint.py: 43.44%
122 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5import ctypes as _ctypes
6from collections.abc import Mapping as _Mapping
7from typing import Any as _Any
9from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return
10from cuda.core._utils.version import binding_version as _binding_version
11from cuda.core._utils.version import driver_version as _driver_version
12from cuda.core.typing import ProcessStateType as _ProcessStateType
14try:
15 from cuda.bindings import driver as _driver
16except ImportError:
17 from cuda import cuda as _driver
# Pairs of (attribute name on ``driver.CUprocessState``, public state name).
# ``_get_process_state_names`` resolves each attribute name with ``getattr``
# to build the lookup table that ``Process.state`` uses.
_PROCESS_STATE_NAME_ATTRS: tuple[tuple[str, _ProcessStateType], ...] = (
    ("CU_PROCESS_STATE_RUNNING", "running"),
    ("CU_PROCESS_STATE_LOCKED", "locked"),
    ("CU_PROCESS_STATE_CHECKPOINTED", "checkpointed"),
    ("CU_PROCESS_STATE_FAILED", "failed"),
)

# Names the driver bindings module must expose; ``_get_driver`` checks each
# one with ``hasattr`` and reports the missing ones in a ``RuntimeError``.
_REQUIRED_BINDING_ATTRS = (
    "cuCheckpointProcessCheckpoint",
    "cuCheckpointProcessGetRestoreThreadId",
    "cuCheckpointProcessGetState",
    "cuCheckpointProcessLock",
    "cuCheckpointProcessRestore",
    "cuCheckpointProcessUnlock",
    "CUcheckpointGpuPair",
    "CUcheckpointLockArgs",
    "CUprocessState",
    "CUcheckpointRestoreArgs",
)
# Lowest value of ``_driver_version()`` that ``_get_driver`` accepts; anything
# smaller makes it raise ``RuntimeError``.
_REQUIRED_DRIVER_VERSION = (12, 8, 0)
# Flipped to True by ``_get_driver`` after all capability checks pass, so
# later calls return the driver module without repeating them.
_driver_capability_checked = False
class Process:
    """
    Handle to a CUDA process for lock / checkpoint / restore / unlock calls.

    The handle stores only the process ID; every operation forwards it to the
    corresponding ``cuCheckpointProcess*`` driver function.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process. Must be a positive ``int``;
        ``bool`` values are rejected.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int`` (or is a ``bool``).
    ValueError
        If ``pid`` is zero or negative.
    """

    __slots__ = ("_pid",)

    def __init__(self, pid: int):
        # Validation lives in the module-level helper; it returns pid unchanged.
        self._pid = _check_pid(pid)

    @property
    def pid(self) -> int:
        """
        Process ID this handle refers to.
        """
        return self._pid

    @property
    def state(self) -> _ProcessStateType:
        """
        Current CUDA checkpoint state of this process, as a state name.

        Raises
        ------
        RuntimeError
            If the driver reports a state that has no known name, or if
            checkpointing is unavailable.
        """
        drv = _get_driver()
        raw_state = _call_driver(drv, drv.cuCheckpointProcessGetState, self._pid)
        names_by_state = _get_process_state_names(drv)
        try:
            name = names_by_state[raw_state]
        except KeyError as e:
            # Surface the numeric value so an unrecognized state can be identified.
            state_value = int(raw_state)
            raise RuntimeError(f"Unknown CUDA checkpoint process state: {state_value}") from e
        return name

    @property
    def restore_thread_id(self) -> int:
        """
        CUDA restore thread ID reported by the driver for this process.
        """
        drv = _get_driver()
        thread_id = _call_driver(drv, drv.cuCheckpointProcessGetRestoreThreadId, self._pid)
        return thread_id

    def lock(self, timeout_ms: int = 0) -> None:
        """
        Lock this process, blocking further CUDA API calls.

        Parameters
        ----------
        timeout_ms : int, optional
            Timeout in milliseconds. A value of 0 indicates no timeout.

        Raises
        ------
        TypeError
            If ``timeout_ms`` is not an ``int`` (or is a ``bool``).
        ValueError
            If ``timeout_ms`` is negative.
        """
        drv = _get_driver()
        lock_args = drv.CUcheckpointLockArgs()
        lock_args.timeoutMs = _check_timeout_ms(timeout_ms)
        _call_driver(drv, drv.cuCheckpointProcessLock, self._pid, lock_args)

    def checkpoint(self) -> None:
        """
        Checkpoint the GPU memory contents of this locked process.
        """
        drv = _get_driver()
        # No argument structure is supplied for this call.
        _call_driver(drv, drv.cuCheckpointProcessCheckpoint, self._pid, None)

    def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None:
        """
        Restore this checkpointed process.

        Parameters
        ----------
        gpu_mapping : mapping, optional
            GPU UUID remapping: each key is a checkpointed GPU UUID and its
            value is the GPU UUID to restore onto. Keys and values may be
            ``CUuuid`` objects or UUID strings. ``None`` or an empty mapping
            restores without any remapping arguments. For migration
            workflows, provide mappings for every GPU visible to the
            kernel-mode driver. User-space masking such as
            ``CUDA_VISIBLE_DEVICES`` does not reduce this mapping
            requirement.

        Raises
        ------
        TypeError
            If ``gpu_mapping`` is neither ``None`` nor a mapping.
        """
        drv = _get_driver()
        restore_args = _make_restore_args(drv, gpu_mapping)
        _call_driver(drv, drv.cuCheckpointProcessRestore, self._pid, restore_args)

    def unlock(self) -> None:
        """
        Unlock this locked process so it can resume CUDA API calls.
        """
        drv = _get_driver()
        # No argument structure is supplied for this call.
        _call_driver(drv, drv.cuCheckpointProcessUnlock, self._pid, None)
def _get_driver():
    """Return the driver bindings module once checkpoint support is confirmed.

    Until a call succeeds, each call verifies in order:

    1. the ``cuda.bindings`` version is one that supports checkpointing,
    2. every name in ``_REQUIRED_BINDING_ATTRS`` exists on the bindings module,
    3. the installed driver version is at least ``_REQUIRED_DRIVER_VERSION``.

    Success is recorded in ``_driver_capability_checked`` so later calls
    return immediately. A failure is not recorded, so the checks run again
    on the next call.

    Raises
    ------
    RuntimeError
        If any of the three checks fails.
    """
    global _driver_capability_checked
    if _driver_capability_checked:
        return _driver

    bindings_ver = _binding_version()
    if not _binding_version_supports_checkpoint(bindings_ver):
        found = ".".join(str(part) for part in bindings_ver[:3])
        raise RuntimeError(
            "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. "
            f"Found cuda.bindings {found}."
        )

    # Empty string (falsy) when nothing is missing.
    absent = ", ".join(name for name in _REQUIRED_BINDING_ATTRS if not hasattr(_driver, name))
    if absent:
        raise RuntimeError(
            f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {absent}"
        )

    installed_driver = _driver_version()
    if installed_driver < _REQUIRED_DRIVER_VERSION:
        raise RuntimeError(
            "CUDA checkpointing is not supported by the installed NVIDIA driver. "
            "Upgrade to a driver version with CUDA checkpoint API support."
        )

    _driver_capability_checked = True
    return _driver
def _binding_version_supports_checkpoint(version) -> bool:
    """Tell whether a ``cuda.bindings`` version includes the checkpoint API.

    Only the first three components of *version* (major, minor, patch) are
    examined. Accepted versions are 12.8.0 and later within major 12,
    13.0.2 and later within major 13, and every major version above 13.
    """
    major, minor, patch = version[:3]
    if major > 13:
        return True
    if major == 13:
        return (minor, patch) >= (0, 2)
    if major == 12:
        return (minor, patch) >= (8, 0)
    return False
def _get_process_state_names(driver) -> dict[_Any, _ProcessStateType]:
    """Build a lookup table from ``driver.CUprocessState`` members to state names.

    Each ``(attribute, name)`` pair in ``_PROCESS_STATE_NAME_ATTRS`` is
    resolved against ``driver.CUprocessState``; the resolved member becomes
    the key and the state name the value.
    """
    names_by_state = {}
    for enum_attr, label in _PROCESS_STATE_NAME_ATTRS:
        names_by_state[getattr(driver.CUprocessState, enum_attr)] = label
    return names_by_state
def _call_driver(driver, func, *args):
    """Call a driver function and normalize "checkpointing unsupported" failures.

    Parameters
    ----------
    driver
        Driver bindings module; its ``CUresult`` is consulted for the
        ``CUDA_ERROR_NOT_FOUND`` and ``CUDA_ERROR_NOT_SUPPORTED`` codes.
    func
        Driver function to invoke.
    *args
        Positional arguments forwarded to ``func``.

    Returns
    -------
    Whatever ``_handle_cuda_return`` produces for the call's result.

    Raises
    ------
    RuntimeError
        With an "upgrade your driver" message when ``func`` raises a
        ``RuntimeError`` mentioning both ``cuCheckpointProcess`` and
        ``not found``, or when the first element of the result is one of the
        two codes above. Any other ``RuntimeError`` from ``func`` propagates
        unchanged.
    """
    unsupported_msg = (
        "CUDA checkpointing is not supported by the installed NVIDIA driver. "
        "Upgrade to a driver version with CUDA checkpoint API support."
    )

    try:
        result = func(*args)
    except RuntimeError as e:
        text = str(e)
        if "cuCheckpointProcess" not in text or "not found" not in text:
            raise
        raise RuntimeError(unsupported_msg) from e

    status = result[0]
    # A code missing from CUresult resolves to None here.
    unsupported_codes = tuple(
        getattr(driver.CUresult, code_name, None)
        for code_name in ("CUDA_ERROR_NOT_FOUND", "CUDA_ERROR_NOT_SUPPORTED")
    )
    if status in unsupported_codes:
        raise RuntimeError(unsupported_msg)

    return _handle_cuda_return(result)
def _check_pid(pid: int) -> int:
    """Validate a process ID and return it unchanged.

    Raises
    ------
    TypeError
        If *pid* is not an ``int``. ``bool`` is rejected even though it
        subclasses ``int``.
    ValueError
        If *pid* is zero or negative.
    """
    is_plain_int = isinstance(pid, int) and not isinstance(pid, bool)
    if not is_plain_int:
        raise TypeError("pid must be an int")
    if pid > 0:
        return pid
    raise ValueError("pid must be a positive int")
def _check_timeout_ms(timeout_ms: int) -> int:
    """Validate a millisecond timeout and return it unchanged.

    Raises
    ------
    TypeError
        If *timeout_ms* is not an ``int``. ``bool`` is rejected even though
        it subclasses ``int``.
    ValueError
        If *timeout_ms* is negative. Zero is accepted.
    """
    is_plain_int = isinstance(timeout_ms, int) and not isinstance(timeout_ms, bool)
    if not is_plain_int:
        raise TypeError("timeout_ms must be an int")
    if timeout_ms >= 0:
        return timeout_ms
    raise ValueError("timeout_ms must be >= 0")
212def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None):
213 if gpu_mapping is None:
214 return None
215 if not isinstance(gpu_mapping, _Mapping):
216 raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID")
218 pairs = []
219 for old_uuid, new_uuid in gpu_mapping.items():
220 pair = driver.CUcheckpointGpuPair()
221 buffers = []
222 pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers)
223 pair.newUuid = _as_cuuuid(driver, new_uuid, buffers)
224 pairs.append(pair)
226 if not pairs:
227 return None
229 args = driver.CUcheckpointRestoreArgs()
230 args.gpuPairs = pairs
231 args.gpuPairsCount = len(pairs)
232 return args
def _as_cuuuid(driver, value, buffers):
    """Convert *value* to a ``CUuuid``.

    Accepts a ``CUuuid`` instance (returned as-is) or a UUID string in
    the ``"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"`` format returned by
    :attr:`Device.uuid`.

    Parameters
    ----------
    driver
        Driver bindings module providing ``CUuuid``.
    value
        UUID string to convert. Any non-string value is returned unchanged
        without further checks.
    buffers : list
        For a string input, the 16-byte ctypes buffer holding the parsed
        UUID is appended here; the returned ``CUuuid`` is constructed from
        that buffer's address, so the caller controls how long it stays
        referenced.

    Raises
    ------
    ValueError
        If a string *value* is not exactly 32 hexadecimal digits once
        hyphens are removed. Non-hex characters and wrong lengths are
        reported with the same message.
    """
    if not isinstance(value, str):
        return value

    error_text = f"GPU UUID string must be 32 hex characters (with optional hyphens), got {value!r}"
    try:
        raw = bytes.fromhex(value.replace("-", ""))
    except ValueError as e:
        # bytes.fromhex's own message never mentions GPU UUIDs; re-raise with
        # the same message used for wrong-length input so callers see one
        # consistent error for every malformed string.
        raise ValueError(error_text) from e
    if len(raw) != 16:
        raise ValueError(error_text)

    buf = _ctypes.create_string_buffer(raw, 16)
    buffers.append(buf)
    return driver.CUuuid(_ctypes.addressof(buf))
# Public API of this module: only ``Process`` is exported by ``import *``;
# every other module-level name here is underscore-prefixed.
__all__ = [
    "Process",
]