Coverage for cuda / core / _memory / _peer_access_utils.pyx: 22.02%

218 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 01:37 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5from __future__ import annotations 

6  

7from collections.abc import Callable, Iterable, MutableSet 

8from dataclasses import dataclass 

9from typing import TYPE_CHECKING 

10  

11from cuda.bindings cimport cydriver 

12from cuda.core._memory._device_memory_resource cimport DeviceMemoryResource 

13from cuda.core._resource_handles cimport as_cu 

14from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

15from cpython.mem cimport PyMem_Malloc, PyMem_Free 

16from libcpp.vector cimport vector 

17  

18if TYPE_CHECKING: 

19 from cuda.core._device import Device 

20  

21  

@dataclass(frozen=True)
class PeerAccessPlan:
    """Immutable outcome of peer-access planning.

    Bundles the peer set that should hold access once the update is applied
    together with the grant/revoke deltas needed to reach it from the current
    driver state.
    """

    # Peer device IDs that should hold access afterwards. As produced by
    # plan_peer_access_update this is sorted, de-duplicated, and excludes the
    # owner device.
    target_ids: tuple[int, ...]
    # Device IDs that must newly be granted access (sorted by the planner).
    to_add: tuple[int, ...]
    # Device IDs whose access must be revoked (sorted by the planner).
    to_remove: tuple[int, ...]
29  

30  

def normalize_peer_access_targets(
    owner_device_id: int,
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
) -> tuple[int, ...]:
    """Resolve requested devices into a sorted tuple of unique peer IDs.

    Args:
        owner_device_id: Ordinal of the device that owns the pool; it is
            dropped from the result because a device is never its own peer.
        requested_devices: Devices (in whatever form ``resolve_device_id``
            accepts) that should be peers.
        resolve_device_id: Maps one requested device to its integer ordinal.
            Any exception it raises propagates to the caller.

    Returns:
        Ascending tuple of distinct peer ordinals, without ``owner_device_id``.
    """
    peer_ids = set(map(resolve_device_id, requested_devices))
    peer_ids.discard(owner_device_id)
    return tuple(sorted(peer_ids))

42  

43  

def plan_peer_access_update(
    owner_device_id: int,
    current_peer_ids: Iterable[int],
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
    can_access_peer: Callable[[int], bool],
) -> PeerAccessPlan:
    """Work out the target peer set and the grant/revoke deltas to reach it.

    Args:
        owner_device_id: Ordinal of the device owning the pool (never a peer).
        current_peer_ids: Peer ordinals that hold access right now.
        requested_devices: Desired peers, resolved via ``resolve_device_id``.
        resolve_device_id: Maps a requested device to its integer ordinal.
        can_access_peer: Predicate checked for every target peer ordinal.

    Returns:
        A :class:`PeerAccessPlan` with the sorted target IDs, the IDs to grant
        (target - current), and the IDs to revoke (current - target).

    Raises:
        ValueError: if ``can_access_peer`` rejects any target peer; every
            rejected ordinal is listed in the message.
    """
    target_ids = normalize_peer_access_targets(
        owner_device_id,
        requested_devices,
        resolve_device_id=resolve_device_id,
    )

    # Check every target before failing so the error lists all offenders.
    unreachable = [peer for peer in target_ids if not can_access_peer(peer)]
    if unreachable:
        listing = ", ".join(map(str, unreachable))
        raise ValueError(f"Device {owner_device_id} cannot access peer(s): {listing}")

    existing = frozenset(current_peer_ids)
    wanted = frozenset(target_ids)
    return PeerAccessPlan(
        target_ids=target_ids,
        to_add=tuple(sorted(wanted - existing)),
        to_remove=tuple(sorted(existing - wanted)),
    )

71  

72  

def _resolve_peer_device_id(value):
    """Turn a ``Device`` or a device-ordinal ``int`` into its integer ordinal.

    Construction goes through ``cuda.core._device.Device``, so any exception
    it raises for an invalid ``value`` propagates unchanged.
    """
    # Imported lazily, as in the rest of this module, rather than at module
    # import time.
    from cuda.core._device import Device

    device = Device(value)
    return device.device_id

78  

79  

80# ---- driver-touching helpers (cdef inline, called from .pyx code) ----------- 

81  

cdef inline tuple _query_peer_access_ids(DeviceMemoryResource mr):
    """Return the IDs of peer devices that currently hold read/write access.

    Every device except the owner is probed with ``cuMemPoolGetAccess``; the
    whole probe loop runs inside one ``nogil`` section. Devices are visited
    in ascending ordinal order, so the returned tuple of ints is already
    sorted.
    """
    cdef int owner = mr._dev_id
    cdef cydriver.CUmemoryPool pool = as_cu(mr._h_pool)
    cdef int device_count
    cdef int candidate
    cdef cydriver.CUmemAccess_flags access
    cdef cydriver.CUmemLocation loc
    cdef vector[int] granted
    cdef size_t k, n_granted

    loc.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE

    with nogil:
        HANDLE_RETURN(cydriver.cuDeviceGetCount(&device_count))
        for candidate in range(device_count):
            # The owner always has access to its own pool; it is not a peer.
            if candidate != owner:
                loc.id = candidate
                HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&access, pool, &loc))
                if access == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE:
                    granted.push_back(candidate)

    n_granted = granted.size()
    return tuple([granted[k] for k in range(n_granted)])

111  

112  

cdef inline bint _peer_access_includes(DeviceMemoryResource mr, int dev_id):
    """Report whether device ``dev_id`` holds read/write access to ``mr``'s pool.

    Issues a single ``cuMemPoolGetAccess`` query; access levels other than
    read/write count as "not granted".
    """
    cdef cydriver.CUmemAccess_flags access
    cdef cydriver.CUmemLocation loc
    cdef cydriver.CUmemoryPool pool = as_cu(mr._h_pool)

    loc.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    loc.id = dev_id

    with nogil:
        HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&access, pool, &loc))

    return access == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

123  

124  

def _set_pool_access(mr, tuple to_add, tuple to_remove):
    """Issue one ``cuMemPoolSetAccess`` for the given add/remove deltas.

    This is the thin Python-callable layer around the actual driver call. It
    builds the ``CUmemAccessDesc`` array and invokes ``cuMemPoolSetAccess``
    once:

    - every ID in ``to_add`` gets ``CU_MEM_ACCESS_FLAGS_PROT_READWRITE``;
    - every ID in ``to_remove`` gets ``CU_MEM_ACCESS_FLAGS_PROT_NONE``.

    Tests monkeypatch this on the module to spy on real driver work without
    intercepting earlier no-op paths.

    Callers are expected to skip empty diffs (see
    :func:`_apply_peer_access_diff`). An empty diff is nevertheless treated
    as a no-op here instead of allocating a zero-length array and issuing a
    zero-length driver call.

    Raises:
        TypeError: if ``mr`` is not a ``DeviceMemoryResource``.
        MemoryError: if the descriptor array cannot be allocated.
    """
    # Checked cast: a wrong object type raises TypeError instead of having its
    # memory reinterpreted as a DeviceMemoryResource.
    cdef DeviceMemoryResource mr_typed = <DeviceMemoryResource?>mr
    cdef size_t count = len(to_add) + len(to_remove)
    cdef cydriver.CUmemAccessDesc* access_desc = NULL
    cdef size_t i = 0

    if count == 0:
        return

    access_desc = <cydriver.CUmemAccessDesc*>PyMem_Malloc(count * sizeof(cydriver.CUmemAccessDesc))
    if access_desc == NULL:
        raise MemoryError("Failed to allocate memory for access descriptors")

    # The allocation succeeded, so the finally clause can free unconditionally.
    try:
        for dev_id in to_add:
            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            access_desc[i].location.id = dev_id
            i += 1
        for dev_id in to_remove:
            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE
            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            access_desc[i].location.id = dev_id
            i += 1

        with nogil:
            HANDLE_RETURN(cydriver.cuMemPoolSetAccess(as_cu(mr_typed._h_pool), access_desc, count))
    finally:
        PyMem_Free(access_desc)

162  

163  

def _apply_peer_access_diff(mr, to_add, to_remove):
    """Apply a peer-access diff using at most one driver call.

    Every write path on :class:`PeerAccessibleBySetProxy` and the
    ``peer_accessible_by`` setter routes through this function. Both
    iterables are materialized as tuples. When both are empty nothing
    happens, so the driver-level helper :func:`_set_pool_access` only runs
    when ``cuMemPoolSetAccess`` has real work to do.
    """
    additions = tuple(to_add)
    removals = tuple(to_remove)
    if additions or removals:
        _set_pool_access(mr, additions, removals)

177  

178  

cpdef replace_peer_accessible_by(DeviceMemoryResource mr, devices):
    """Make ``devices`` the complete peer-access set of ``mr``'s pool.

    Backs the ``mr.peer_accessible_by = [...]`` setter. The same planner used
    by the proxy's bulk operations compares the current driver state with the
    requested target set:

    - requested peers not yet granted are added;
    - granted peers no longer requested are revoked.

    Both happen in at most one batched driver call.

    Raises:
        ValueError: if the owner device cannot access one of the requested
            peers (raised by :func:`plan_peer_access_update`).
    """
    from cuda.core._device import Device

    owner_id = mr._dev_id
    owner = Device(owner_id)
    current = _query_peer_access_ids(mr)
    plan = plan_peer_access_update(
        owner_device_id=owner_id,
        current_peer_ids=current,
        requested_devices=devices,
        resolve_device_id=_resolve_peer_device_id,
        can_access_peer=owner.can_access_peer,
    )
    _apply_peer_access_diff(mr, plan.to_add, plan.to_remove)

198  

199  

200# ---- Python MutableSet proxy ------------------------------------------------ 

201  

class PeerAccessibleBySetProxy(MutableSet):
    """Live driver-backed view of the peer devices granted access to a memory pool.

    Reads (``__contains__``, ``__iter__``, ``len(...)``) call ``cuMemPoolGetAccess``;
    writes (``add``, ``discard``, and bulk ops) call ``cuMemPoolSetAccess``. There
    is no in-memory mirror, so the view always reflects the current driver state
    and stays consistent across multiple wrappers around the same pool.

    Iteration yields :class:`~cuda.core.Device` objects. ``add``, ``discard``, and
    ``__contains__`` accept either a :class:`~cuda.core.Device` or a device-ordinal
    ``int``; the owner device is silently ignored when supplied.

    All bulk operations (``update``, ``|=``, ``&=``, ``-=``, ``^=``, ``clear``)
    issue at most one ``cuMemPoolSetAccess`` call, and none at all when the
    requested change is a no-op. This matters: peer-access transitions can take
    seconds per pool because every existing memory mapping is updated, so
    coalescing into a single driver call lets the toolkit handle the mappings
    in parallel.
    """

    __slots__ = ("_mr",)

    def __init__(self, mr):
        self._mr = mr

    @classmethod
    def _from_iterable(cls, it):
        # Binary set operators (&, |, -, ^) collect their result through
        # _from_iterable. Returning a plain set lets the user reason about
        # the result independently of any pool's driver state.
        return set(it)

    # --- abstract MutableSet methods ---

    def __contains__(self, value) -> bool:
        """Return True if ``value`` (``Device | int``) currently has peer access.

        Unresolvable values and the owner device report False.
        """
        try:
            dev_id = _resolve_peer_device_id(value)
        except (TypeError, ValueError):
            return False
        cdef DeviceMemoryResource mr = <DeviceMemoryResource?>self._mr
        if dev_id == mr._dev_id:
            return False
        return _peer_access_includes(mr, dev_id)

    def __iter__(self):
        """Yield a ``Device`` for each peer that currently holds access."""
        from cuda.core._device import Device

        # A generator expression is already an iterator; the peer IDs are
        # queried once, when this generator is created.
        return (Device(dev_id) for dev_id in _query_peer_access_ids(self._mr))

    def __len__(self) -> int:
        return len(_query_peer_access_ids(self._mr))

    def add(self, value) -> None:
        """Grant peer access from ``value`` to allocations in this pool.

        Raises:
            ValueError: if the owner device cannot access ``value`` as a peer.
        """
        dev_id = _resolve_peer_device_id(value)
        cdef DeviceMemoryResource mr = <DeviceMemoryResource?>self._mr
        if dev_id == mr._dev_id:
            return
        if _peer_access_includes(mr, dev_id):
            return
        from cuda.core._device import Device
        if not Device(mr._dev_id).can_access_peer(dev_id):
            raise ValueError(f"Device {mr._dev_id} cannot access peer: {dev_id}")
        _apply_peer_access_diff(mr, (dev_id,), ())

    def discard(self, value) -> None:
        """Revoke peer access from ``value`` to allocations in this pool.

        Unresolvable values, the owner device, and devices without access are
        ignored.
        """
        try:
            dev_id = _resolve_peer_device_id(value)
        except (TypeError, ValueError):
            return
        cdef DeviceMemoryResource mr = <DeviceMemoryResource?>self._mr
        if dev_id == mr._dev_id:
            return
        if not _peer_access_includes(mr, dev_id):
            return
        _apply_peer_access_diff(mr, (), (dev_id,))

    # --- bulk overrides: at most one driver call per op ---

    def clear(self) -> None:
        """Revoke all peer access in a single driver call."""
        self._apply((), _query_peer_access_ids(self._mr))

    def update(self, *others) -> None:
        """Grant peer access to every device in ``others`` in one driver call."""
        to_add = []
        for other in others:
            to_add.extend(other)
        if to_add:
            self._apply(to_add, ())

    def difference_update(self, *others) -> None:
        """Revoke peer access for every device in ``others`` in one driver call."""
        revoke_ids = set()
        for other in others:
            for value in other:
                try:
                    revoke_ids.add(_resolve_peer_device_id(value))
                except (TypeError, ValueError):
                    continue
        current = set(_query_peer_access_ids(self._mr))
        to_remove = revoke_ids & current
        if to_remove:
            self._apply((), to_remove)

    def intersection_update(self, *others) -> None:
        """Restrict peer access to the intersection in a single driver call."""
        keep_ids = None
        for other in others:
            ids = set()
            for value in other:
                try:
                    ids.add(_resolve_peer_device_id(value))
                except (TypeError, ValueError):
                    continue
            keep_ids = ids if keep_ids is None else keep_ids & ids
        if keep_ids is None:
            return  # ``set.intersection_update()`` with no args is a no-op
        current = set(_query_peer_access_ids(self._mr))
        to_remove = current - keep_ids
        if to_remove:
            self._apply((), to_remove)

    def symmetric_difference_update(self, other) -> None:
        """Toggle peer access for every device in ``other`` in one driver call."""
        toggle_ids = set()
        for value in other:
            try:
                toggle_ids.add(_resolve_peer_device_id(value))
            except (TypeError, ValueError):
                continue
        current = set(_query_peer_access_ids(self._mr))
        to_add = toggle_ids - current
        to_remove = toggle_ids & current
        if to_add or to_remove:
            self._apply(to_add, to_remove)

    def __ior__(self, other):
        self.update(other)
        return self

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __isub__(self, other):
        if other is self:
            self.clear()
        else:
            self.difference_update(other)
        return self

    def __ixor__(self, other):
        self.symmetric_difference_update(other)
        return self

    def __repr__(self) -> str:
        return f"PeerAccessibleBySetProxy({set(self)!r})"

    # --- internal: route every write through one batched driver call ---

    def _apply(self, additions, removals) -> None:
        """Compute the diff and issue at most one ``cuMemPoolSetAccess``.

        ``additions`` and ``removals`` are user-supplied (``Device | int``);
        the owner device is filtered out of both.

        - Requested additions are validated through
          :meth:`Device.can_access_peer` via :func:`plan_peer_access_update`.
          An unresolvable addition propagates the resolver's error.
        - Removals bypass that check: revoking is always permitted.
          Unresolvable or not-currently-granted removals are skipped.

        Raises:
            ValueError: if a requested addition cannot be a peer of the owner.
        """
        from cuda.core._device import Device

        cdef DeviceMemoryResource mr = <DeviceMemoryResource?>self._mr
        owner_id = mr._dev_id
        owner = Device(owner_id)
        current = _query_peer_access_ids(mr)

        # Plan only the requested additions. The planner's ``to_add`` is
        # (requested - current), which is exactly the grant delta. Its
        # ``to_remove`` (current - requested) is deliberately ignored because
        # removals are handled below. Leaving already-granted peers out of the
        # request also keeps them from being re-validated, so pure revocations
        # (``clear()``, ``-=``, ``&=``) can never fail the can_access_peer check.
        plan = plan_peer_access_update(
            owner_device_id=owner_id,
            current_peer_ids=current,
            requested_devices=additions,
            resolve_device_id=_resolve_peer_device_id,
            can_access_peer=owner.can_access_peer,
        )
        to_add = plan.to_add

        # Removals: resolve, drop owner and unknowns, intersect with current.
        current_set = set(current)
        revoke_ids = set()
        for value in removals:
            try:
                dev_id = _resolve_peer_device_id(value)
            except (TypeError, ValueError):
                continue
            if dev_id != owner_id and dev_id in current_set:
                revoke_ids.add(dev_id)
        to_remove = tuple(sorted(revoke_ids))

        if not to_add and not to_remove:
            return
        _apply_peer_access_diff(mr, to_add, to_remove)