Coverage for cuda/core/_memory/_managed_location.py: 89.19%

37 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-13 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# SPDX-License-Identifier: Apache-2.0 

3 

4from __future__ import annotations 

5 

6from dataclasses import dataclass 

7from typing import TYPE_CHECKING, Literal 

8 

9from cuda.core._utils.version import binding_version, driver_version 

10 

11if TYPE_CHECKING: 

12 from cuda.core._device import Device 

13 from cuda.core._host import Host 

14 

15_LocationKind = Literal["device", "host", "host_numa", "host_numa_current"] 

16 

17 

18@dataclass(frozen=True) 

19class _LocSpec: 

20 """Internal location record produced by :func:`_coerce_location`. 

21 

22 Carries the discriminator (``kind``) and the integer payload (``id``) 

23 that the Cython layer in ``_managed_memory_ops.pyx`` consumes when 

24 building ``CUmemLocation`` structs (CUDA 13+) or legacy device 

25 ordinals (CUDA 12). 

26 """ 

27 

28 kind: _LocationKind 

29 id: int = 0 

30 

31 

32def _reject_numa_host_on_cuda12(spec: _LocSpec) -> None: 

33 """Reject NUMA-host kinds on CUDA 12 builds at the public boundary. 

34 

35 The CUDA 12 ``cuMemPrefetchAsync`` / ``cuMemAdvise`` ABI takes a 

36 plain device ordinal (``-1`` for host), so it cannot represent a 

37 specific host NUMA node. Rather than letting the operation fail 

38 deep inside the Cython layer with ``RuntimeError``, raise a 

39 ``TypeError`` at the call boundary with actionable wording. 

40 """ 

41 # The host-NUMA kinds map to CU_MEM_LOCATION_TYPE_HOST_NUMA{,_CURRENT}, 

42 # both added in CUDA 13. Require both bindings and the runtime driver to 

43 # be 13.0+; bindings-only is insufficient (PR #2054 / #2064 precedent). 

44 if binding_version() >= (13, 0, 0) and driver_version() >= (13, 0, 0): 1heabcf

45 return 1heabcf

46 if spec.kind in ("host_numa", "host_numa_current"): 

47 raise TypeError( 

48 "Host(numa_id=...) / Host.numa_current() require both cuda-bindings 13.0+ " 

49 "and a CUDA 13+ runtime driver; use Host() instead" 

50 ) 

51 

52 

def _coerce_location(value: Device | Host | None, *, allow_none: bool = False) -> _LocSpec | None:
    """Normalize a user-supplied location into a ``_LocSpec``.

    Accepted inputs and their results:

    * an existing ``_LocSpec`` -- validated and returned unchanged;
    * a :class:`Device` -- ``kind="device"`` with ``id`` taken from
      ``value.device_id``;
    * a :class:`Host` -- ``"host_numa_current"`` when
      ``value.is_numa_current`` is truthy, ``"host_numa"`` (with
      ``id=value.numa_id``) when ``value.numa_id`` is not ``None``,
      otherwise plain ``"host"``;
    * ``None`` -- returned as ``None`` when ``allow_none`` is true.

    NUMA-host results are passed through
    :func:`_reject_numa_host_on_cuda12`, which raises ``TypeError`` when
    the bindings or driver predate CUDA 13.

    Raises
    ------
    ValueError
        If ``value`` is ``None`` and ``allow_none`` is false.
    TypeError
        If ``value`` is of an unsupported type, or a NUMA-host location
        is requested without CUDA 13 support.
    """
    # Imported here rather than at module scope to avoid import cycles
    # (Device pulls in CUDA init).
    from cuda.core._device import Device
    from cuda.core._host import Host

    # Already-normalized records still go through the NUMA gate.
    if isinstance(value, _LocSpec):
        _reject_numa_host_on_cuda12(value)
        return value

    if isinstance(value, Device):
        return _LocSpec(kind="device", id=value.device_id)

    if isinstance(value, Host):
        # Pick the NUMA-aware record first; the plain host case needs no
        # version gate and returns straight away.
        if value.is_numa_current:
            numa_spec = _LocSpec(kind="host_numa_current")
        elif value.numa_id is not None:
            numa_spec = _LocSpec(kind="host_numa", id=value.numa_id)
        else:
            return _LocSpec(kind="host")
        _reject_numa_host_on_cuda12(numa_spec)
        return numa_spec

    if value is None:
        if not allow_none:
            raise ValueError("location is required")
        return None

    raise TypeError(f"location must be a Device, Host, or None; got {type(value).__name__}")