Coverage for cuda / core / experimental / _memory / _graph_memory_resource.pyx: 91%
80 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7from libc.stdint cimport intptr_t
9from cuda.bindings cimport cydriver
10from cuda.core.experimental._memory._buffer cimport Buffer, MemoryResource
11from cuda.core.experimental._stream cimport default_stream, Stream_accept, Stream
12from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
14from functools import cache
15from typing import TYPE_CHECKING
17if TYPE_CHECKING:
18 from cuda.core.experimental._memory.buffer import DevicePointerT
20__all__ = ['GraphMemoryResource']
cdef class GraphMemoryResourceAttributes:
    """
    Graph-memory attributes of one device, read and reset through
    cuDeviceGetGraphMemAttribute / cuDeviceSetGraphMemAttribute.

    Obtain an instance from a memory resource's ``attributes`` property;
    direct construction is rejected.
    """
    cdef:
        int _dev_id  # device ordinal passed to the driver attribute calls

    def __init__(self, *args, **kwargs):
        raise RuntimeError("GraphMemoryResourceAttributes cannot be instantiated directly. Please use MemoryResource APIs.")

    @classmethod
    def _init(cls, device_id: int):
        # Bypass __init__ (which always raises) and fill in the device ordinal.
        cdef GraphMemoryResourceAttributes attrs = GraphMemoryResourceAttributes.__new__(cls)
        attrs._dev_id = device_id
        return attrs

    def __repr__(self):
        # Render every public attribute as name=value, in dir() order.
        fields = [
            f"{name}={getattr(self, name)}"
            for name in dir(self)
            if not name.startswith("_")
        ]
        return f"{self.__class__.__name__}({', '.join(fields)})"

    cdef int _getattribute(self, cydriver.CUgraphMem_attribute attr_enum, void* value) except?-1:
        # Read one graph-memory attribute of this device into *value.
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceGetGraphMemAttribute(self._dev_id, attr_enum, value))
        return 0

    cdef int _setattribute(self, cydriver.CUgraphMem_attribute attr_enum, void* value) except?-1:
        # Write one graph-memory attribute of this device from *value.
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceSetGraphMemAttribute(self._dev_id, attr_enum, value))
        return 0

    @property
    def reserved_mem_current(self):
        """Current amount of backing memory allocated."""
        cdef cydriver.cuuint64_t out
        self._getattribute(
            cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT, &out
        )
        return int(out)

    @property
    def reserved_mem_high(self):
        """
        High watermark of backing memory allocated. Assigning zero resets it
        to the current usage; any other value is rejected.
        """
        cdef cydriver.cuuint64_t out
        self._getattribute(
            cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH, &out
        )
        return int(out)

    @reserved_mem_high.setter
    def reserved_mem_high(self, value: int):
        cdef cydriver.cuuint64_t reset_value = 0
        # The driver only supports resetting the watermark, i.e. writing zero.
        if value != 0:
            raise AttributeError(f"Attribute 'reserved_mem_high' may only be set to zero (got {value}).")
        self._setattribute(
            cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH, &reset_value
        )

    @property
    def used_mem_current(self):
        """Current amount of memory in use."""
        cdef cydriver.cuuint64_t out
        self._getattribute(
            cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT, &out
        )
        return int(out)

    @property
    def used_mem_high(self):
        """
        High watermark of memory in use. Assigning zero resets it to the
        current usage; any other value is rejected.
        """
        cdef cydriver.cuuint64_t out
        self._getattribute(
            cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, &out
        )
        return int(out)

    @used_mem_high.setter
    def used_mem_high(self, value: int):
        cdef cydriver.cuuint64_t reset_value = 0
        # The driver only supports resetting the watermark, i.e. writing zero.
        if value != 0:
            raise AttributeError(f"Attribute 'used_mem_high' may only be set to zero (got {value}).")
        self._setattribute(
            cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, &reset_value
        )
cdef class cyGraphMemoryResource(MemoryResource):
    """Cython base of GraphMemoryResource; stores the device ordinal in ``_dev_id``."""

    def __cinit__(self, int device_id):
        self._dev_id = device_id

    def allocate(self, size_t size, stream: Stream | GraphBuilder | None = None) -> Buffer:
        """
        Allocate a buffer of the requested size. See documentation for :obj:`~_memory.MemoryResource`.
        """
        # No stream given: fall back to the default stream.
        if stream is None:
            stream = default_stream()
        else:
            stream = Stream_accept(stream)
        return GMR_allocate(self, size, <Stream> stream)

    def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream | GraphBuilder | None = None):
        """
        Deallocate a buffer of the requested size. See documentation for :obj:`~_memory.MemoryResource`.
        """
        # No stream given: fall back to the default stream.
        if stream is None:
            stream = default_stream()
        else:
            stream = Stream_accept(stream)
        # GMR_deallocate returns void; this method returns None.
        GMR_deallocate(ptr, size, <Stream> stream)

    def close(self):
        """No operation (provided for compatibility)."""
        pass

    def trim(self):
        """Free unused memory that was cached on the specified device for use with graphs back to the OS."""
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceGraphMemTrim(self._dev_id))

    @property
    def attributes(self) -> GraphMemoryResourceAttributes:
        """Asynchronous allocation attributes related to graphs."""
        # A fresh attributes view is created on every access.
        return GraphMemoryResourceAttributes._init(self._dev_id)

    @property
    def device_id(self) -> int:
        """The associated device ordinal."""
        return self._dev_id

    @property
    def is_device_accessible(self) -> bool:
        """Return True. This memory resource provides device-accessible buffers."""
        return True

    @property
    def is_host_accessible(self) -> bool:
        """Return False. This memory resource does not provide host-accessible buffers."""
        return False
class GraphMemoryResource(cyGraphMemoryResource):
    """
    A memory resource for memory related to graphs.

    The only supported operations are allocation, deallocation, and a limited
    set of status queries.

    This memory resource should be used when building graphs. Allocating from
    it on a stream that is not actively capturing a graph raises a
    ``RuntimeError``.

    Conversely, allocating memory from a `DeviceMemoryResource` while graph
    capture is enabled results in a runtime error.

    Instances are cached: requesting a ``GraphMemoryResource`` more than once
    for the same device (given either as a Device or as its ordinal) returns
    the same object.

    Parameters
    ----------
    device_id: int | Device
        Device or Device ordinal for which a graph memory resource is obtained.
    """

    def __new__(cls, device_id: int | Device):
        # Accept either a Device (anything exposing ``device_id``) or a plain
        # ordinal, and normalize to the ordinal so both hit the same cache key.
        cdef int c_device_id = getattr(device_id, 'device_id', device_id)
        return cls._create(c_device_id)

    @classmethod
    @cache
    def _create(cls, int device_id):
        # functools.cache keys on (cls, device_id), so each class gets exactly
        # one instance per device ordinal.
        return cyGraphMemoryResource.__new__(cls, device_id)
# Raise RuntimeError unless the given stream is actively capturing a graph.
# Any capture status other than CU_STREAM_CAPTURE_STATUS_ACTIVE -- including a
# stream that is not capturing at all and CU_STREAM_CAPTURE_STATUS_INVALIDATED --
# is treated as an error. The cuStreamIsCapturing result is checked with
# HANDLE_RETURN. Returns 0 on success (the except?-1 error contract).
cdef inline int check_capturing(cydriver.CUstream s) except?-1 nogil:
    cdef cydriver.CUstreamCaptureStatus capturing
    HANDLE_RETURN(cydriver.cuStreamIsCapturing(s, &capturing))
    if capturing != cydriver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("GraphMemoryResource cannot perform memory operations on "
                           "a non-capturing stream.")
    return 0
cdef inline Buffer GMR_allocate(cyGraphMemoryResource self, size_t size, Stream stream):
    # Allocate `size` bytes with cuMemAllocAsync on `stream`. The stream must be
    # actively capturing (enforced by check_capturing). The resulting device
    # pointer is wrapped in a Buffer that records `self` as its memory resource
    # and `stream` as its allocation stream.
    cdef cydriver.CUdeviceptr dptr
    cdef cydriver.CUstream hstream = stream._handle
    cdef Buffer result

    # Driver calls run without the GIL.
    with nogil:
        check_capturing(hstream)
        HANDLE_RETURN(cydriver.cuMemAllocAsync(&dptr, size, hstream))

    result = Buffer.__new__(Buffer)
    result._memory_resource = self
    result._alloc_stream = stream
    result._ptr = <intptr_t>(dptr)
    result._size = size
    result._ptr_obj = None
    return result
cdef inline void GMR_deallocate(intptr_t ptr, size_t size, Stream stream) except *:
    # Free the device pointer `ptr` with cuMemFreeAsync, ordered on `stream`.
    # NOTE(review): `ptr` presumably comes from GMR_allocate -- confirm callers.
    # `size` is accepted for symmetry with the allocation path; cuMemFreeAsync
    # does not take it.
    # Declared `except *` rather than `noexcept`: with `noexcept`, an error
    # raised by HANDLE_RETURN here would be printed as an unraisable exception
    # and dropped instead of reaching the caller of deallocate().
    cdef cydriver.CUstream s = stream._handle
    cdef cydriver.CUdeviceptr devptr = <cydriver.CUdeviceptr>ptr
    with nogil:
        HANDLE_RETURN(cydriver.cuMemFreeAsync(devptr, s))