Coverage for cuda / core / _memory / _managed_buffer.py: 90.10%

101 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 01:37 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# SPDX-License-Identifier: Apache-2.0 

3 

4from __future__ import annotations 

5 

6from collections.abc import MutableSet 

7from typing import TYPE_CHECKING 

8 

9from cuda.core._device import Device 

10from cuda.core._host import Host 

11from cuda.core._memory._buffer import Buffer 

12from cuda.core._memory._managed_memory_ops import ( 

13 _advise_one, 

14 _do_single_discard_prefetch_py, 

15 _do_single_discard_py, 

16 _do_single_prefetch_py, 

17 _read_preferred_location_v2, 

18) 

19from cuda.core._utils.cuda_utils import driver, handle_return 

20from cuda.core._utils.version import binding_version, driver_version 

21 

22if TYPE_CHECKING: 

23 from cuda.core._memory._buffer import MemoryResource 

24 from cuda.core._stream import Stream 

25 from cuda.core.graph import GraphBuilder 

26 

27 

# Byte size of a single integer slot; used as (or to compute) the data size
# passed to cuMemRangeGetAttribute by the helpers below.
_INT_SIZE = 4

# cuMemAdvise advice values, resolved once at import time so that property
# writes do not repeat the enum attribute lookup.
_ADV = driver.CUmem_advise
_SET_READ_MOSTLY, _UNSET_READ_MOSTLY = (
    _ADV.CU_MEM_ADVISE_SET_READ_MOSTLY,
    _ADV.CU_MEM_ADVISE_UNSET_READ_MOSTLY,
)
_SET_PREFERRED, _UNSET_PREFERRED = (
    _ADV.CU_MEM_ADVISE_SET_PREFERRED_LOCATION,
    _ADV.CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
)
_SET_ACCESSED_BY, _UNSET_ACCESSED_BY = (
    _ADV.CU_MEM_ADVISE_SET_ACCESSED_BY,
    _ADV.CU_MEM_ADVISE_UNSET_ACCESSED_BY,
)

# cuMemRangeGetAttribute selectors used by the property getters.
_RANGE = driver.CUmem_range_attribute
_ATTR_READ_MOSTLY = _RANGE.CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY
_ATTR_PREFERRED = _RANGE.CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION
_ATTR_ACCESSED_BY = _RANGE.CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY

43 

44 

def _get_int_attr(buf: Buffer, attribute) -> int:
    """Query a single integer-sized range attribute of ``buf``.

    Calls ``cuMemRangeGetAttribute`` with a data size of ``_INT_SIZE`` bytes,
    passing ``buf.handle`` and ``buf.size`` as the range, and returns whatever
    ``handle_return`` yields for the driver's result.
    """
    driver_result = driver.cuMemRangeGetAttribute(_INT_SIZE, attribute, buf.handle, buf.size)
    return handle_return(driver_result)

47 

48 

def _query_accessed_by(buf: Buffer) -> list[Device | Host]:
    """Return the locations currently listed in ``CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY``.

    The driver fills an int32 array in which each slot holds a device id,
    ``-1`` for the host, or ``-2`` for an empty slot. The query is sized to
    ``cuDeviceGetCount() + 1`` slots (every visible device plus the host).
    """
    slot_count = handle_return(driver.cuDeviceGetCount()) + 1
    slots = handle_return(
        driver.cuMemRangeGetAttribute(slot_count * _INT_SIZE, _ATTR_ACCESSED_BY, buf.handle, buf.size)
    )
    locations: list[Device | Host] = []
    for ordinal in slots:
        if ordinal == -2:
            # Empty slot: nothing recorded here.
            continue
        locations.append(Host() if ordinal == -1 else Device(ordinal))
    return locations

59 

60 

class AccessedBySetProxy(MutableSet):
    """Set-like view over the ``set_accessed_by`` advice of a managed buffer.

    The proxy keeps no local copy of the membership. Every read
    (``__contains__``, ``__iter__``, ``len(...)``) queries the driver through
    ``cuMemRangeGetAttribute`` and every write (``add``, ``discard``) issues
    ``cuMemAdvise``, so the view always reflects the current driver state.

    Note
    ----
    The driver reports integer device ordinals (``-1`` for host). Host NUMA
    distinctions applied via ``Host(numa_id=...)`` therefore collapse to a
    generic ``Host()`` when this set is iterated.
    """

    __slots__ = ("_buf",)

    def __init__(self, buf: ManagedBuffer):
        # The buffer whose advice state this proxy reads and writes.
        self._buf = buf

    @classmethod
    def _from_iterable(cls, it):
        # Hook used by the inherited set operators (&, |, ^, ...): their
        # result is a plain ``set`` rather than another driver-backed proxy.
        return set(it)

    # --- abstract methods required by MutableSet ---

    def __contains__(self, location) -> bool:
        """Report whether ``location`` is in the live accessed-by list."""
        if isinstance(location, (Device, Host)):
            return location in _query_accessed_by(self._buf)
        # Anything that is not a Device/Host can never be a member.
        return False

    def __iter__(self):
        """Iterate over a fresh snapshot of the accessed-by list."""
        snapshot = _query_accessed_by(self._buf)
        return iter(snapshot)

    def __len__(self) -> int:
        """Return the number of entries in a fresh snapshot."""
        snapshot = _query_accessed_by(self._buf)
        return len(snapshot)

    def add(self, location: Device | Host) -> None:
        """Apply ``set_accessed_by`` advice for ``location``.

        Raises
        ------
        TypeError
            If ``location`` is neither a ``Device`` nor a ``Host``.
        """
        if isinstance(location, (Device, Host)):
            _advise_one(self._buf, _SET_ACCESSED_BY, location)
            return
        raise TypeError(f"expected Device or Host, got {type(location).__name__}")

    def discard(self, location: Device | Host) -> None:
        """Apply ``unset_accessed_by`` advice for ``location``.

        ``MutableSet.discard`` must be a no-op for elements that are not in
        the set. ``set_accessed_by`` only accepts ``Device`` and the generic
        ``Host()``; NUMA-aware host variants (``Host(numa_id=...)``,
        ``Host.numa_current()``) can never enter the set, so a request to
        discard one is silently ignored instead of being forwarded to the
        driver. Objects of any other type are ignored as well.
        """
        if isinstance(location, Host):
            numa_aware = location.numa_id is not None or location.is_numa_current
            if numa_aware:
                return
        elif not isinstance(location, Device):
            return
        _advise_one(self._buf, _UNSET_ACCESSED_BY, location)

    def __repr__(self) -> str:
        members = set(_query_accessed_by(self._buf))
        return f"AccessedBySetProxy({members!r})"

122 

123 

class ManagedBuffer(Buffer):
    """Buffer over managed (unified) memory exposing advice as properties.

    Instances are returned by :meth:`ManagedMemoryResource.allocate`; an
    existing managed-memory pointer can be wrapped with
    :meth:`ManagedBuffer.from_handle`.

    Examples
    --------
    >>> buf = mr.allocate(size)
    >>> buf.read_mostly = True
    >>> buf.preferred_location = Device(0)
    >>> buf.accessed_by.add(Device(1))
    >>> buf.prefetch(Device(0), stream=stream)

    Note
    ----
    On CUDA 13 builds, ``preferred_location`` round-trips full NUMA
    information. On CUDA 12 builds, ``Host(numa_id=...)`` and
    ``Host.numa_current()`` are rejected with ``TypeError`` at the call
    boundary; only ``Device(...)`` and the generic ``Host()`` are accepted.
    Use ``Host()`` to target the host on CUDA 12.
    """

    @classmethod
    def from_handle(
        cls,
        ptr,
        size: int,
        mr: MemoryResource | None = None,
        owner: object | None = None,
    ) -> ManagedBuffer:
        """Build a :class:`ManagedBuffer` around an existing managed pointer.

        Intended for externally-allocated managed memory that should gain
        the property-style advice API (:attr:`read_mostly`,
        :attr:`preferred_location`, :attr:`accessed_by`).

        Parameters
        ----------
        ptr : :obj:`~_memory.DevicePointerT`
            Pointer to a managed allocation.
        size : int
            Allocation size in bytes.
        mr : :obj:`~_memory.MemoryResource`, optional
            Memory resource that owns ``ptr``. When provided, its
            ``deallocate`` is called when the buffer is closed.
        owner : object, optional
            An object that keeps the underlying allocation alive.
            ``owner`` and ``mr`` cannot both be specified.
        """
        wrapped = cls._init(ptr, size, mr=mr, owner=owner)
        return wrapped

    @property
    def read_mostly(self) -> bool:
        """Whether ``set_read_mostly`` advice is currently applied."""
        flag = _get_int_attr(self, _ATTR_READ_MOSTLY)
        return flag != 0

    @read_mostly.setter
    def read_mostly(self, value: bool) -> None:
        # A truthy value applies the advice; a falsy one removes it.
        if value:
            advice = _SET_READ_MOSTLY
        else:
            advice = _UNSET_READ_MOSTLY
        _advise_one(self, advice, None)

    @property
    def preferred_location(self) -> Device | Host | None:
        """Currently applied ``set_preferred_location`` target, or ``None``.

        On CUDA 13 builds, ``Host(numa_id=N)`` round-trips fully. On CUDA 12
        the legacy attribute carries only a device ordinal (or ``-1`` for
        host), so ``Host(numa_id=N)`` set via the setter reads back as
        ``Host()``.
        """
        # The v2 path relies on CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION_{TYPE,ID},
        # both added in CUDA 13, so the bindings AND the runtime driver must
        # be 13.0+. Checking the bindings alone regressed before; see
        # PR #2054 / #2064.
        v2_minimum = (13, 0, 0)
        if binding_version() >= v2_minimum and driver_version() >= v2_minimum:
            return _read_preferred_location_v2(self)

        # Legacy CUDA 12 path: no NUMA information is available. This is also
        # the path taken with 13.x bindings on a 12.x runtime driver.
        ordinal = _get_int_attr(self, _ATTR_PREFERRED)
        if ordinal == -1:
            return Host()
        # -2 means no preferred location is set.
        return None if ordinal == -2 else Device(ordinal)

    @preferred_location.setter
    def preferred_location(self, value: Device | Host | None) -> None:
        # ``None`` clears the preference; any other value is forwarded as the
        # new target (on the unset path the location passed along is None).
        advice = _UNSET_PREFERRED if value is None else _SET_PREFERRED
        _advise_one(self, advice, value)

    @property
    def accessed_by(self) -> AccessedBySetProxy:
        """Live set-like view of ``set_accessed_by`` locations."""
        view = AccessedBySetProxy(self)
        return view

    @accessed_by.setter
    def accessed_by(self, locations) -> None:
        # Type-check every requested entry before the first cuMemAdvise call,
        # so a wrongly-typed element cannot leave accessed_by half-updated.
        # NOTE(review): only the type is checked here. If ``_advise_one``
        # rejects a value later (e.g. a NUMA-aware Host), the unset pass below
        # may already have run -- confirm whether that is acceptable.
        requested: set[Device | Host] = set()
        for entry in locations:
            if isinstance(entry, (Device, Host)):
                requested.add(entry)
                continue
            raise TypeError(f"accessed_by entries must be Device or Host, got {type(entry).__name__}")

        live = set(_query_accessed_by(self))
        stale = live - requested
        missing = requested - live
        # Drop what is no longer wanted first, then add what is new.
        for entry in stale:
            _advise_one(self, _UNSET_ACCESSED_BY, entry)
        for entry in missing:
            _advise_one(self, _SET_ACCESSED_BY, entry)

    def prefetch(self, location: Device | Host, *, stream: Stream | GraphBuilder) -> None:
        """Prefetch this range to ``location`` on ``stream``."""
        _do_single_prefetch_py(self, location, stream)

    def discard(self, *, stream: Stream | GraphBuilder) -> None:
        """Discard this range's resident pages on ``stream`` (CUDA 13+)."""
        _do_single_discard_py(self, stream)

    def discard_prefetch(self, location: Device | Host, *, stream: Stream | GraphBuilder) -> None:
        """Discard this range and prefetch to ``location`` on ``stream`` (CUDA 13+)."""
        _do_single_discard_prefetch_py(self, location, stream)