Coverage for cuda / core / _memory / _peer_access_utils.py: 100.00%
21 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7from collections.abc import Callable, Iterable
8from dataclasses import dataclass
@dataclass(frozen=True)
class PeerAccessPlan:
    """Normalized peer-access target state and the driver updates it requires.

    Instances are immutable (frozen dataclass). As built by
    ``plan_peer_access_update`` every field is a sorted tuple of device IDs.

    Attributes:
        target_ids: Peer device IDs that make up the requested target state.
        to_add: Target IDs that are not among the current peers.
        to_remove: Current peer IDs that are not part of the target state.
    """

    target_ids: tuple[int, ...]
    to_add: tuple[int, ...]
    to_remove: tuple[int, ...]
def normalize_peer_access_targets(
    owner_device_id: int,
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
) -> tuple[int, ...]:
    """Return sorted, unique peer device IDs, excluding the owner device.

    Args:
        owner_device_id: ID of the owning device; it is dropped from the
            result if it appears among the requested devices.
        requested_devices: Device-like objects to resolve. The iterable is
            consumed exactly once, so a one-shot iterator is acceptable.
        resolve_device_id: Callback mapping each requested object to its
            integer device ID.

    Returns:
        The resolved IDs as an ascending tuple with duplicates removed and
        ``owner_device_id`` excluded. Empty when nothing remains.
    """
    # A set collapses duplicate requests that resolve to the same device.
    target_ids = {resolve_device_id(device) for device in requested_devices}
    # discard() is a no-op when the owner was not requested.
    target_ids.discard(owner_device_id)
    return tuple(sorted(target_ids))
def plan_peer_access_update(
    owner_device_id: int,
    current_peer_ids: Iterable[int],
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
    can_access_peer: Callable[[int], bool],
) -> PeerAccessPlan:
    """Compute the peer-access target state and add/remove deltas.

    Args:
        owner_device_id: ID of the owning device; it is never part of the
            target state.
        current_peer_ids: Peer device IDs that are enabled at present.
        requested_devices: Device-like objects describing the desired peers.
        resolve_device_id: Callback mapping each requested object to its
            integer device ID.
        can_access_peer: Callback returning whether the owner device can
            access the peer with the given ID.

    Returns:
        A ``PeerAccessPlan`` whose ``target_ids`` is the normalized request,
        ``to_add`` holds target IDs missing from ``current_peer_ids``, and
        ``to_remove`` holds current IDs absent from the target. All three
        are sorted tuples.

    Raises:
        ValueError: If any target peer is rejected by ``can_access_peer``.
            The message lists every rejected ID.
    """
    target_ids = normalize_peer_access_targets(
        owner_device_id,
        requested_devices,
        resolve_device_id=resolve_device_id,
    )

    # Check every target before computing deltas so that a single error
    # reports all inaccessible peers and no plan is produced for a bad request.
    bad = tuple(dev_id for dev_id in target_ids if not can_access_peer(dev_id))
    if bad:
        bad_ids = ", ".join(str(dev_id) for dev_id in bad)
        raise ValueError(f"Device {owner_device_id} cannot access peer(s): {bad_ids}")

    current_ids = set(current_peer_ids)
    target_id_set = set(target_ids)
    return PeerAccessPlan(
        target_ids=target_ids,
        to_add=tuple(sorted(target_id_set - current_ids)),
        to_remove=tuple(sorted(current_ids - target_id_set)),
    )