Coverage for cuda / core / _memory / _peer_access_utils.py: 100.00%

21 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-25 01:07 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7from collections.abc import Callable, Iterable 

8from dataclasses import dataclass 

9 

10 

@dataclass(frozen=True)
class PeerAccessPlan:
    """Immutable outcome of planning a peer-access change.

    Bundles the normalized peer-access target state together with the
    add/remove deltas (the driver updates) needed to reach it.
    """

    # Peer device IDs that should be accessible once the update is applied.
    target_ids: tuple[int, ...]
    # IDs in target_ids that are missing from the current peer IDs.
    to_add: tuple[int, ...]
    # IDs in the current peer IDs that are no longer in target_ids.
    to_remove: tuple[int, ...]

18 

19 

def normalize_peer_access_targets(
    owner_device_id: int,
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
) -> tuple[int, ...]:
    """Map requested devices to a canonical tuple of peer device IDs.

    Each entry of ``requested_devices`` is passed through
    ``resolve_device_id`` in iteration order. Duplicate IDs are collapsed,
    the owner device itself is dropped, and the remaining IDs are returned
    in ascending order.
    """
    unique_ids: set[int] = set()
    for device in requested_devices:
        unique_ids.add(resolve_device_id(device))

    # The owner is never listed as one of its own peers.
    peers = sorted(unique_ids - {owner_device_id})
    return tuple(peers)

31 

32 

def plan_peer_access_update(
    owner_device_id: int,
    current_peer_ids: Iterable[int],
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
    can_access_peer: Callable[[int], bool],
) -> PeerAccessPlan:
    """Work out which peer-access entries must be added or removed.

    ``requested_devices`` is normalized via ``normalize_peer_access_targets``
    (resolved, deduplicated, owner excluded, sorted). Every resulting ID is
    then checked with ``can_access_peer`` before any delta is computed.

    Returns:
        A ``PeerAccessPlan`` whose ``to_add``/``to_remove`` are sorted
        tuples of the differences between the targets and
        ``current_peer_ids``.

    Raises:
        ValueError: if ``can_access_peer`` is falsy for any target ID; the
            message lists all such IDs in ascending order.
    """
    targets = normalize_peer_access_targets(
        owner_device_id,
        requested_devices,
        resolve_device_id=resolve_device_id,
    )

    # Validate all targets up front so no partial plan is ever produced.
    inaccessible = [peer for peer in targets if not can_access_peer(peer)]
    if inaccessible:
        listing = ", ".join(map(str, inaccessible))
        raise ValueError(f"Device {owner_device_id} cannot access peer(s): {listing}")

    existing = frozenset(current_peer_ids)
    desired = frozenset(targets)
    return PeerAccessPlan(
        target_ids=targets,
        to_add=tuple(sorted(desired.difference(existing))),
        to_remove=tuple(sorted(existing.difference(desired))),
    )