Coverage for cuda / core / _memory / _peer_access_utils.pyx: 22.02%
218 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7from collections.abc import Callable, Iterable, MutableSet
8from dataclasses import dataclass
9from typing import TYPE_CHECKING
11from cuda.bindings cimport cydriver
12from cuda.core._memory._device_memory_resource cimport DeviceMemoryResource
13from cuda.core._resource_handles cimport as_cu
14from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
15from cpython.mem cimport PyMem_Malloc, PyMem_Free
16from libcpp.vector cimport vector
18if TYPE_CHECKING:
19 from cuda.core._device import Device
@dataclass(frozen=True)
class PeerAccessPlan:
    """Normalized peer-access target state and the driver updates it requires.

    Instances are immutable (``frozen=True``). ``plan_peer_access_update`` in
    this module builds them with every tuple sorted in ascending order.
    """

    # Peer device IDs that should hold access once the plan has been applied.
    target_ids: tuple[int, ...]
    # IDs present in ``target_ids`` but not in the current peer set (to grant).
    to_add: tuple[int, ...]
    # IDs present in the current peer set but not in ``target_ids`` (to revoke).
    to_remove: tuple[int, ...]
def normalize_peer_access_targets(
    owner_device_id: int,
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
) -> tuple[int, ...]:
    """Resolve the requested devices into a canonical tuple of peer IDs.

    Args:
        owner_device_id: ID of the device that owns the pool; it never
            appears in the result.
        requested_devices: Arbitrary device-like values to resolve.
        resolve_device_id: Callback mapping one requested value to its
            device ID. Any exception it raises propagates to the caller.

    Returns:
        The distinct resolved IDs, minus the owner, in ascending order.
    """
    # Deduplicate first; the owner is dropped with a set difference so an
    # absent owner is not an error.
    unique_ids = set(map(resolve_device_id, requested_devices))
    peers_only = unique_ids - {owner_device_id}
    ordered = sorted(peers_only)
    return tuple(ordered)
def plan_peer_access_update(
    owner_device_id: int,
    current_peer_ids: Iterable[int],
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
    can_access_peer: Callable[[int], bool],
) -> PeerAccessPlan:
    """Work out which peers must be granted or revoked to reach the request.

    Args:
        owner_device_id: ID of the device that owns the pool.
        current_peer_ids: Peer IDs that currently hold access.
        requested_devices: Device-like values describing the desired peer set.
        resolve_device_id: Callback mapping one requested value to a device ID.
        can_access_peer: Callback reporting whether the owner can reach a
            given peer ID.

    Returns:
        A :class:`PeerAccessPlan` carrying the normalized target IDs together
        with the sorted IDs to add and to remove.

    Raises:
        ValueError: If ``can_access_peer`` is false for any target ID. The
            message lists every offending ID.
    """
    desired = normalize_peer_access_targets(
        owner_device_id,
        requested_devices,
        resolve_device_id=resolve_device_id,
    )

    # Check every target before failing so the error names all bad peers.
    unreachable = [peer for peer in desired if not can_access_peer(peer)]
    if unreachable:
        bad_ids = ", ".join(map(str, unreachable))
        raise ValueError(f"Device {owner_device_id} cannot access peer(s): {bad_ids}")

    existing = set(current_peer_ids)
    wanted = set(desired)
    grants = sorted(wanted - existing)
    revokes = sorted(existing - wanted)
    return PeerAccessPlan(
        target_ids=desired,
        to_add=tuple(grants),
        to_remove=tuple(revokes),
    )
def _resolve_peer_device_id(value):
    """Return the ``device_id`` of the ``Device`` constructed from ``value``.

    ``value`` is expected to be a ``Device`` or a device-ordinal ``int``.
    Whatever the ``Device`` constructor raises for other inputs propagates.
    """
    # Imported at call time rather than at module import time.
    from cuda.core._device import Device

    device = Device(value)
    return device.device_id
80# ---- driver-touching helpers (cdef inline, called from .pyx code) -----------
cdef inline tuple _query_peer_access_ids(DeviceMemoryResource mr):
    """Ask the driver which peer devices currently have read-write pool access.

    Every device ordinal reported by ``cuDeviceGetCount`` except the owner is
    probed with ``cuMemPoolGetAccess``. All driver calls happen inside one
    ``nogil`` section. Ordinals are visited in ascending order, so the
    returned tuple of ints is sorted without an explicit sort.
    """
    cdef int self_id = mr._dev_id
    cdef cydriver.CUmemoryPool pool = as_cu(mr._h_pool)
    cdef cydriver.CUmemLocation loc
    cdef cydriver.CUmemAccess_flags access
    cdef vector[int] granted
    cdef int device_count
    cdef int ordinal
    cdef size_t idx, granted_count

    # The location type is the same for every probe; only the id changes.
    loc.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE

    with nogil:
        HANDLE_RETURN(cydriver.cuDeviceGetCount(&device_count))
        for ordinal in range(device_count):
            # The owner device is never reported as its own peer.
            if ordinal != self_id:
                loc.id = ordinal
                HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&access, pool, &loc))
                if access == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE:
                    granted.push_back(ordinal)

    granted_count = granted.size()
    return tuple([granted[idx] for idx in range(granted_count)])
cdef inline bint _peer_access_includes(DeviceMemoryResource mr, int dev_id):
    """Report whether device ``dev_id`` currently has read-write access to the pool.

    Issues a single ``cuMemPoolGetAccess`` query with the GIL released.
    """
    cdef cydriver.CUmemoryPool pool = as_cu(mr._h_pool)
    cdef cydriver.CUmemLocation peer_loc
    cdef cydriver.CUmemAccess_flags granted

    peer_loc.id = dev_id
    peer_loc.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    with nogil:
        HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&granted, pool, &peer_loc))
    return granted == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
def _set_pool_access(mr, tuple to_add, tuple to_remove):
    """Apply the given grants and revocations with one ``cuMemPoolSetAccess``.

    This is the thin Python-callable layer around the real driver work: it
    builds the ``CUmemAccessDesc`` array (read-write entries for ``to_add``,
    no-access entries for ``to_remove``) and hands it to the driver. Tests
    monkeypatch this function on the module to observe driver work without
    intercepting the earlier no-op paths.

    Precondition: ``len(to_add) + len(to_remove) > 0``. Skipping empty diffs
    is the caller's responsibility.

    Raises:
        MemoryError: If the descriptor array cannot be allocated.
    """
    cdef DeviceMemoryResource typed_mr = <DeviceMemoryResource>mr
    cdef size_t n_desc = len(to_add) + len(to_remove)
    cdef cydriver.CUmemAccessDesc* descs = NULL
    cdef size_t slot = 0

    descs = <cydriver.CUmemAccessDesc*>PyMem_Malloc(n_desc * sizeof(cydriver.CUmemAccessDesc))
    if descs == NULL:
        raise MemoryError("Failed to allocate memory for access descriptors")

    try:
        # Grants occupy the front of the array, revocations the tail.
        for peer_id in to_add:
            descs[slot].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            descs[slot].location.id = peer_id
            descs[slot].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
            slot += 1
        for peer_id in to_remove:
            descs[slot].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            descs[slot].location.id = peer_id
            descs[slot].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE
            slot += 1

        with nogil:
            HANDLE_RETURN(cydriver.cuMemPoolSetAccess(as_cu(typed_mr._h_pool), descs, n_desc))
    finally:
        # ``descs`` is known to be non-NULL here: the NULL case raised above.
        PyMem_Free(descs)
164def _apply_peer_access_diff(mr, to_add, to_remove):
165 """Apply a peer-access diff in at most one driver call.
167 Every write path on :class:`PeerAccessibleBySetProxy` and the
168 ``peer_accessible_by`` setter routes through this function. Empty diffs
169 short-circuit here so the driver-level helper :func:`_set_pool_access` is
170 only invoked when there is actual work for ``cuMemPoolSetAccess`` to do.
171 """
172 add_tuple = tuple(to_add)
173 remove_tuple = tuple(to_remove)
174 if not add_tuple and not remove_tuple:
175 return
176 _set_pool_access(mr, add_tuple, remove_tuple)
cpdef replace_peer_accessible_by(DeviceMemoryResource mr, devices):
    """Make ``devices`` the complete peer-access set of ``mr``'s pool.

    This backs the ``mr.peer_accessible_by = [...]`` setter. The current
    driver state is queried and handed to :func:`plan_peer_access_update`
    together with the requested devices. The resulting grants and revocations
    are then applied as one batch.

    Raises:
        ValueError: Propagated from the planner when the owner device cannot
            access one of the requested peers.
    """
    from cuda.core._device import Device

    owner = Device(mr._dev_id)
    granted_now = _query_peer_access_ids(mr)
    plan = plan_peer_access_update(
        mr._dev_id,
        granted_now,
        devices,
        resolve_device_id=_resolve_peer_device_id,
        can_access_peer=owner.can_access_peer,
    )
    _apply_peer_access_diff(mr, plan.to_add, plan.to_remove)
200# ---- Python MutableSet proxy ------------------------------------------------
class PeerAccessibleBySetProxy(MutableSet):
    """Mutable-set view over the peer devices that may access a memory pool.

    Nothing is mirrored on the Python side. Membership tests, iteration and
    ``len(...)`` query the driver through ``cuMemPoolGetAccess``, while
    ``add``, ``discard`` and the bulk operations write through
    ``cuMemPoolSetAccess``. The view therefore always shows the driver's
    current state, also when several wrappers refer to the same pool.

    Iterating produces :class:`~cuda.core.Device` objects. ``add``,
    ``discard`` and ``in`` take either a :class:`~cuda.core.Device` or a
    device-ordinal ``int``; the owner device is silently ignored when given.

    Each bulk operation (``update``, ``|=``, ``&=``, ``-=``, ``^=``,
    ``clear``) is coalesced into a single ``cuMemPoolSetAccess`` call. That
    matters because peer-access transitions can take seconds per pool, as
    every existing memory mapping is updated; one batched call lets the
    toolkit handle the mappings in parallel.
    """

    __slots__ = ("_mr",)

    def __init__(self, mr):
        self._mr = mr

    @classmethod
    def _from_iterable(cls, it):
        # The binary operators (&, |, -, ^) build their result through this
        # hook. Handing back a plain set gives the caller a value that is
        # independent of any pool's driver state.
        return set(it)

    @staticmethod
    def _lenient_ids(values):
        """Resolve ``values`` to a set of device IDs, skipping unresolvable ones."""
        resolved = set()
        for candidate in values:
            try:
                resolved.add(_resolve_peer_device_id(candidate))
            except (TypeError, ValueError):
                continue
        return resolved

    # --- abstract MutableSet methods ---

    def __contains__(self, value) -> bool:
        try:
            peer_id = _resolve_peer_device_id(value)
        except (TypeError, ValueError):
            # Something that is not a device cannot be a member.
            return False
        cdef DeviceMemoryResource pool_mr = <DeviceMemoryResource>self._mr
        if peer_id == pool_mr._dev_id:
            # The owner is never reported as one of its own peers.
            return False
        return _peer_access_includes(pool_mr, peer_id)

    def __iter__(self):
        from cuda.core._device import Device

        # The driver is queried now; Device objects are built lazily.
        peer_ids = _query_peer_access_ids(self._mr)
        return (Device(peer_id) for peer_id in peer_ids)

    def __len__(self) -> int:
        peer_ids = _query_peer_access_ids(self._mr)
        return len(peer_ids)

    def add(self, value) -> None:
        """Grant peer access from ``value`` to allocations in this pool."""
        peer_id = _resolve_peer_device_id(value)
        cdef DeviceMemoryResource pool_mr = <DeviceMemoryResource>self._mr
        # Nothing to do for the owner itself or for an existing grant.
        if peer_id == pool_mr._dev_id or _peer_access_includes(pool_mr, peer_id):
            return
        from cuda.core._device import Device
        owner_id = pool_mr._dev_id
        if not Device(owner_id).can_access_peer(peer_id):
            raise ValueError(f"Device {owner_id} cannot access peer: {peer_id}")
        _apply_peer_access_diff(pool_mr, (peer_id,), ())

    def discard(self, value) -> None:
        """Revoke peer access from ``value`` to allocations in this pool."""
        try:
            peer_id = _resolve_peer_device_id(value)
        except (TypeError, ValueError):
            return
        cdef DeviceMemoryResource pool_mr = <DeviceMemoryResource>self._mr
        # Only an existing grant to a non-owner device needs a driver call.
        if peer_id == pool_mr._dev_id or not _peer_access_includes(pool_mr, peer_id):
            return
        _apply_peer_access_diff(pool_mr, (), (peer_id,))

    # --- bulk overrides: one driver call per op ---

    def clear(self) -> None:
        """Revoke all peer access in a single driver call."""
        granted = _query_peer_access_ids(self._mr)
        self._apply((), granted)

    def update(self, *others) -> None:
        """Grant peer access to every device in ``others`` in one driver call."""
        requested = [device for group in others for device in group]
        if requested:
            self._apply(requested, ())

    def difference_update(self, *others) -> None:
        """Revoke peer access for every device in ``others`` in one driver call."""
        revoke_ids = set()
        for group in others:
            revoke_ids |= self._lenient_ids(group)
        granted = set(_query_peer_access_ids(self._mr))
        stale = revoke_ids & granted
        if stale:
            self._apply((), stale)

    def intersection_update(self, *others) -> None:
        """Restrict peer access to the intersection in a single driver call."""
        if not others:
            return  # ``set.intersection_update()`` with no args is a no-op
        id_sets = [self._lenient_ids(group) for group in others]
        keep_ids = set.intersection(*id_sets)
        granted = set(_query_peer_access_ids(self._mr))
        stale = granted - keep_ids
        if stale:
            self._apply((), stale)

    def symmetric_difference_update(self, other) -> None:
        """Toggle peer access for every device in ``other`` in one driver call."""
        toggle_ids = self._lenient_ids(other)
        granted = set(_query_peer_access_ids(self._mr))
        grants = toggle_ids - granted
        revokes = toggle_ids & granted
        if grants or revokes:
            self._apply(grants, revokes)

    def __ior__(self, other):
        self.update(other)
        return self

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __isub__(self, other):
        # ``s -= s`` empties the set; route it to ``clear``.
        if other is self:
            self.clear()
            return self
        self.difference_update(other)
        return self

    def __ixor__(self, other):
        self.symmetric_difference_update(other)
        return self

    def __repr__(self) -> str:
        members = set(self)
        return f"PeerAccessibleBySetProxy({members!r})"

    # --- internal: route every write through one batched driver call ---

    def _apply(self, additions, removals) -> None:
        """Turn user-supplied additions/removals into one batched driver update.

        ``additions`` and ``removals`` hold ``Device | int`` values as given
        by the caller; the owner device is dropped from both. Additions go
        through :func:`plan_peer_access_update`, which checks them with
        :meth:`Device.can_access_peer`. Removals skip that check (revoking is
        always permitted) and are limited to peers that currently hold access.
        """
        from cuda.core._device import Device

        cdef DeviceMemoryResource pool_mr = <DeviceMemoryResource>self._mr
        owner_id = pool_mr._dev_id
        owner = Device(owner_id)
        granted = _query_peer_access_ids(pool_mr)

        # Requesting "current peers + additions" makes the planner emit only
        # the grants for these additions and no removals.
        plan = plan_peer_access_update(
            owner_device_id=owner_id,
            current_peer_ids=granted,
            requested_devices=[*granted, *additions],
            resolve_device_id=_resolve_peer_device_id,
            can_access_peer=owner.can_access_peer,
        )
        grant_ids = plan.to_add

        # Removals: skip unresolvable values, the owner, and non-members.
        granted_set = set(granted)
        revoke_set = set()
        for candidate in removals:
            try:
                peer_id = _resolve_peer_device_id(candidate)
            except (TypeError, ValueError):
                continue
            if peer_id == owner_id or peer_id not in granted_set:
                continue
            revoke_set.add(peer_id)
        revoke_ids = tuple(sorted(revoke_set))

        if grant_ids or revoke_ids:
            _apply_peer_access_diff(pool_mr, grant_ids, revoke_ids)