Coverage for cuda / core / _memory / _graph_memory_resource.pyx: 82.43%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-08 01:07 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5from __future__ import annotations 

6  

7from libc.stdint cimport intptr_t 

8  

9from cuda.bindings cimport cydriver 

10from cuda.core._memory._buffer cimport Buffer, Buffer_from_deviceptr_handle, MemoryResource 

11from cuda.core._resource_handles cimport ( 

12 DevicePtrHandle, 

13 deviceptr_alloc_async, 

14 as_cu, 

15) 

16  

17from cuda.core._stream cimport default_stream, Stream_accept, Stream 

18from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

19  

20from functools import cache 

21  

22__all__ = ['GraphMemoryResource'] 

23  

24  

cdef class GraphMemoryResourceAttributes:
    """
    Read (and partially reset) the graph-memory attributes of one device.

    Each property forwards to ``cuDeviceGetGraphMemAttribute`` or
    ``cuDeviceSetGraphMemAttribute`` with the device ordinal stored on the
    instance. Instances are built through :meth:`_init`; calling the class
    directly raises ``RuntimeError``.
    """
    cdef:
        int _device_id

    def __init__(self, *args, **kwargs):
        raise RuntimeError("GraphMemoryResourceAttributes cannot be instantiated directly. Please use MemoryResource APIs.")

    @classmethod
    def _init(cls, device_id: int):
        """Create an instance bound to ``device_id``, bypassing ``__init__``."""
        cdef GraphMemoryResourceAttributes obj = GraphMemoryResourceAttributes.__new__(cls)
        obj._device_id = device_id
        return obj

    def __repr__(self):
        # Every public name reported by dir() is rendered as ``name=value``;
        # reading the value goes through the corresponding property.
        template = f"{self.__class__.__name__}(%s)"
        public_names = [name for name in dir(self) if not name.startswith("_")]
        return template % ", ".join(
            f"{name}={getattr(self, name)}" for name in public_names
        )

    cdef int _getattribute(self, cydriver.CUgraphMem_attribute attr_enum, void* value) except?-1:
        # Reads attribute ``attr_enum`` of the stored device into ``value``.
        # The driver call runs without the GIL; its status is handed to
        # HANDLE_RETURN. Returns 0 when HANDLE_RETURN lets execution continue.
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceGetGraphMemAttribute(self._device_id, attr_enum, value))
        return 0

    cdef int _setattribute(self, cydriver.CUgraphMem_attribute attr_enum, void* value) except?-1:
        # Writes ``value`` to attribute ``attr_enum`` of the stored device.
        # Same calling pattern as _getattribute.
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceSetGraphMemAttribute(self._device_id, attr_enum, value))
        return 0

    @property
    def reserved_mem_current(self):
        """Amount of backing memory currently allocated."""
        cdef cydriver.cuuint64_t raw
        self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT, &raw)
        return int(raw)

    @property
    def reserved_mem_high(self):
        """
        High watermark of backing memory allocated.

        Assigning zero resets the watermark to the current usage; any other
        value is rejected with ``AttributeError``.
        """
        cdef cydriver.cuuint64_t raw
        self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH, &raw)
        return int(raw)

    @reserved_mem_high.setter
    def reserved_mem_high(self, value: int):
        cdef cydriver.cuuint64_t reset_value = 0
        # Only a reset (zero) is accepted for this watermark.
        if value != 0:
            raise AttributeError(f"Attribute 'reserved_mem_high' may only be set to zero (got {value}).")
        self._setattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH, &reset_value)

    @property
    def used_mem_current(self):
        """Amount of memory currently in use."""
        cdef cydriver.cuuint64_t raw
        self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT, &raw)
        return int(raw)

    @property
    def used_mem_high(self):
        """
        High watermark of memory in use.

        Assigning zero resets the watermark to the current usage; any other
        value is rejected with ``AttributeError``.
        """
        cdef cydriver.cuuint64_t raw
        self._getattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, &raw)
        return int(raw)

    @used_mem_high.setter
    def used_mem_high(self, value: int):
        cdef cydriver.cuuint64_t reset_value = 0
        # Only a reset (zero) is accepted for this watermark.
        if value != 0:
            raise AttributeError(f"Attribute 'used_mem_high' may only be set to zero (got {value}).")
        self._setattribute(cydriver.CUgraphMem_attribute.CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, &reset_value)

101  

102  

cdef class cyGraphMemoryResource(MemoryResource):
    """
    Extension-type base of :class:`GraphMemoryResource`.

    Stores a device ordinal and routes allocation and deallocation through
    the module-level ``GMR_allocate`` and ``GMR_deallocate`` helpers.
    """

    def __cinit__(self, int device_id):
        self._device_id = device_id

    def allocate(self, size_t size, stream: Stream | GraphBuilder | None = None) -> Buffer:
        """
        Allocate a buffer of ``size`` bytes on ``stream``.

        When ``stream`` is ``None`` the default stream is used. See the
        documentation for :obj:`~_memory.MemoryResource`.
        """
        if stream is None:
            stream = default_stream()
        else:
            stream = Stream_accept(stream)
        return GMR_allocate(self, size, <Stream> stream)

    def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream | GraphBuilder | None = None):
        """
        Release the buffer at ``ptr`` on ``stream``.

        When ``stream`` is ``None`` the default stream is used. See the
        documentation for :obj:`~_memory.MemoryResource`.
        """
        if stream is None:
            stream = default_stream()
        else:
            stream = Stream_accept(stream)
        return GMR_deallocate(ptr, size, <Stream> stream)

    def close(self):
        """No operation (provided for compatibility)."""

    def trim(self):
        """Return unused memory cached for graphs on this device to the OS."""
        cdef int ordinal = self._device_id
        # The driver call runs without the GIL; HANDLE_RETURN inspects its status.
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceGraphMemTrim(ordinal))

    @property
    def attributes(self) -> GraphMemoryResourceAttributes:
        """Graph-related asynchronous allocation attributes for this device."""
        return GraphMemoryResourceAttributes._init(self._device_id)

    @property
    def device_id(self) -> int:
        """Ordinal of the device this resource is bound to."""
        return self._device_id

    @property
    def is_device_accessible(self) -> bool:
        """Always True: buffers from this resource are device-accessible."""
        return True

    @property
    def is_host_accessible(self) -> bool:
        """Always False: buffers from this resource are not host-accessible."""
        return False

149  

150  

class GraphMemoryResource(cyGraphMemoryResource):
    """
    Memory resource for allocations that belong to graphs.

    It supports allocation, deallocation, and a small set of status
    queries, and nothing else.

    Use this resource while building graphs. Using it when graph capture is
    not enabled results in a runtime error. Conversely, allocating from a
    `DeviceMemoryResource` while graph capture is enabled also results in a
    runtime error.

    Parameters
    ----------
    device_id: int | Device
        Device or Device ordinal for which a graph memory resource is obtained.
    """

    def __new__(cls, device_id: int | Device):
        # Accept either an object exposing ``device_id`` or a plain ordinal.
        cdef int ordinal = getattr(device_id, 'device_id', device_id)
        return cls._create(ordinal)

    @classmethod
    @cache
    def _create(cls, int device_id):
        # functools.cache memoizes on (cls, device_id), so repeated requests
        # for the same device return the same instance.
        return cyGraphMemoryResource.__new__(cls, device_id)

178  

179  

# Verify that stream ``s`` is actively capturing; raise RuntimeError otherwise.
# Any status other than CU_STREAM_CAPTURE_STATUS_ACTIVE is rejected, so
# CU_STREAM_CAPTURE_STATUS_INVALIDATED counts as an error as well.
cdef inline int check_capturing(cydriver.CUstream s) except?-1 nogil:
    cdef cydriver.CUstreamCaptureStatus status
    HANDLE_RETURN(cydriver.cuStreamIsCapturing(s, &status))
    if status == cydriver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return 0
    raise RuntimeError("GraphMemoryResource cannot perform memory operations on "
                       "a non-capturing stream.")

188  

189  

# Allocate ``size`` bytes on ``stream`` and wrap the result in a Buffer owned
# by ``self``. The stream must be capturing (see check_capturing).
cdef inline Buffer GMR_allocate(cyGraphMemoryResource self, size_t size, Stream stream):
    cdef DevicePtrHandle h_ptr
    cdef cydriver.CUstream c_stream = as_cu(stream._h_stream)
    with nogil:
        # Reject non-capturing streams before asking for memory.
        check_capturing(c_stream)
        h_ptr = deviceptr_alloc_async(size, stream._h_stream)
    # A falsy handle means the asynchronous allocation did not produce a pointer.
    if not h_ptr:
        raise RuntimeError("Failed to allocate memory asynchronously")
    return Buffer_from_deviceptr_handle(h_ptr, size, self, None)

199  

200  

# Free the device memory at ``ptr`` asynchronously on ``stream``.
# ``size`` is part of the signature but is not used by the body.
# NOTE(review): the function is ``noexcept``, so an exception raised by
# HANDLE_RETURN cannot propagate to the caller — confirm this is intended.
cdef inline void GMR_deallocate(intptr_t ptr, size_t size, Stream stream) noexcept:
    cdef cydriver.CUdeviceptr device_address = <cydriver.CUdeviceptr>ptr
    cdef cydriver.CUstream c_stream = as_cu(stream._h_stream)
    with nogil:
        HANDLE_RETURN(cydriver.cuMemFreeAsync(device_address, c_stream))