Coverage for cuda / core / _memory / _managed_location.py: 89.19%
37 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
6from dataclasses import dataclass
7from typing import Literal
9from cuda.core._utils.version import binding_version, driver_version
# Discriminator values carried by ``_LocSpec.kind``. The two host-NUMA kinds
# correspond to CU_MEM_LOCATION_TYPE_HOST_NUMA{,_CURRENT}, which were added in
# CUDA 13.
_LocationKind = Literal[
    "device",
    "host",
    "host_numa",
    "host_numa_current",
]
@dataclass(frozen=True)
class _LocSpec:
    """Immutable memory-location record built by :func:`_coerce_location`.

    The Cython layer in ``_managed_memory_ops.pyx`` reads this record to
    build ``CUmemLocation`` structs (CUDA 13+) or legacy device ordinals
    (CUDA 12). ``kind`` selects the location type and ``id`` carries the
    integer payload that goes with it.
    """

    # Location type discriminator; one of the ``_LocationKind`` literals.
    kind: _LocationKind
    # Integer payload for ``kind`` (the device id for "device", the NUMA node
    # for "host_numa"); left at 0 for kinds that carry no payload.
    id: int = 0
def _reject_numa_host_on_cuda12(spec: _LocSpec) -> None:
    """Reject host-NUMA location kinds when CUDA 13 support is unavailable.

    The host-NUMA kinds (``"host_numa"`` and ``"host_numa_current"``) map to
    ``CU_MEM_LOCATION_TYPE_HOST_NUMA{,_CURRENT}``, both added in CUDA 13. The
    CUDA 12 ``cuMemPrefetchAsync`` / ``cuMemAdvise`` ABI takes a plain device
    ordinal (``-1`` for host), so it cannot represent a specific host NUMA
    node. Rather than letting the operation fail deep inside the Cython layer
    with ``RuntimeError``, raise a ``TypeError`` at the call boundary with
    actionable wording.

    Only ``spec.kind`` is inspected; ``"device"`` and ``"host"`` specs are
    always accepted without querying any versions.

    Raises:
        TypeError: ``spec.kind`` is a host-NUMA kind and either the installed
            cuda-bindings or the runtime driver is older than 13.0.
    """
    # Device and plain-host kinds are never rejected, so return before any
    # version query. This keeps the ``_LocSpec`` pass-through path in
    # ``_coerce_location`` from calling binding_version()/driver_version()
    # for locations that need no gating.
    if spec.kind not in ("host_numa", "host_numa_current"):
        return
    # Require both bindings and the runtime driver to be 13.0+;
    # bindings-only is insufficient (PR #2054 / #2064 precedent).
    if binding_version() >= (13, 0, 0) and driver_version() >= (13, 0, 0):
        return
    raise TypeError(
        "Host(numa_id=...) / Host.numa_current() require both cuda-bindings 13.0+ "
        "and a CUDA 13+ runtime driver; use Host() instead"
    )
def _coerce_location(value, *, allow_none: bool = False) -> _LocSpec | None:
    """Normalize a user-supplied location into a ``_LocSpec``.

    Accepted inputs:

    * an existing ``_LocSpec``: returned as-is after the host-NUMA gate;
    * :class:`Device`: ``kind="device"`` with ``id`` set to its ``device_id``;
    * :class:`Host`: ``Host()`` becomes ``kind="host"``, ``Host(numa_id=N)``
      becomes ``kind="host_numa"`` with ``id=N``, and ``Host.numa_current()``
      becomes ``kind="host_numa_current"``;
    * ``None``: returned unchanged when ``allow_none`` is true.

    Host-NUMA results go through :func:`_reject_numa_host_on_cuda12`, which
    raises ``TypeError`` unless both cuda-bindings and the runtime driver are
    13.0+, because the legacy CUDA 12 ABI cannot represent them.

    Raises:
        ValueError: ``value`` is ``None`` and ``allow_none`` is false.
        TypeError: ``value`` has an unsupported type, or a host-NUMA
            location was requested without CUDA 13 support.
    """
    # Imported lazily to avoid import cycles (Device pulls in CUDA init).
    from cuda.core._device import Device
    from cuda.core._host import Host

    if isinstance(value, _LocSpec):
        # Pre-built specs still go through the host-NUMA version gate.
        _reject_numa_host_on_cuda12(value)
        return value

    if isinstance(value, Device):
        return _LocSpec(kind="device", id=value.device_id)

    if isinstance(value, Host):
        if value.is_numa_current:
            numa_spec = _LocSpec(kind="host_numa_current")
        elif value.numa_id is not None:
            numa_spec = _LocSpec(kind="host_numa", id=value.numa_id)
        else:
            # Plain host memory is representable everywhere; no gate needed.
            return _LocSpec(kind="host")
        _reject_numa_host_on_cuda12(numa_spec)
        return numa_spec

    if value is not None:
        raise TypeError(f"location must be a Device, Host, or None; got {type(value).__name__}")
    if not allow_none:
        raise ValueError("location is required")
    return None