Coverage for cuda / core / graph / _graph_builder.pyx: 90.98%

366 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-29 01:27 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5import weakref 

6from dataclasses import dataclass 

7  

8from libc.stdint cimport intptr_t 

9  

10from cuda.bindings cimport cydriver 

11  

12from cuda.core.graph._utils cimport _attach_host_callback_to_graph 

13from cuda.core._resource_handles cimport as_cu 

14from cuda.core._stream cimport Stream 

15from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

16from cuda.core._utils.version cimport cy_binding_version, cy_driver_version 

17  

18from cuda.core._utils.cuda_utils import ( 

19 CUDAError, 

20 driver, 

21 handle_return, 

22) 

23  

@dataclass
class GraphDebugPrintOptions:
    """Switches controlling what ``debug_dot_print()`` includes in its DOT output.

    Every switch is off (``False``) unless it is set explicitly.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag were enabled.
    runtime_types : bool
        Use CUDA Runtime structures for the output.
    kernel_node_params : bool
        Include kernel parameter values.
    memcpy_node_params : bool
        Include memcpy parameter values.
    memset_node_params : bool
        Include memset parameter values.
    host_node_params : bool
        Include host parameter values.
    event_node_params : bool
        Include event parameter values.
    ext_semas_signal_node_params : bool
        Include external semaphore signal parameter values.
    ext_semas_wait_node_params : bool
        Include external semaphore wait parameter values.
    kernel_node_attributes : bool
        Include kernel node attributes.
    handles : bool
        Include node handles and every kernel function handle.
    mem_alloc_node_params : bool
        Include memory alloc parameter values.
    mem_free_node_params : bool
        Include memory free parameter values.
    batch_mem_op_node_params : bool
        Include batch mem op parameter values.
    extra_topo_info : bool
        Include edge numbering information.
    conditional_node_params : bool
        Include conditional node parameter values.

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Fold the enabled options into a CUDA driver API flag mask (internal use)."""
        # (option attribute, matching CUgraphDebugDot_flags enumerator), in field order.
        option_to_enumerator = (
            ("verbose", "CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE"),
            ("runtime_types", "CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES"),
            ("kernel_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS"),
            ("memcpy_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS"),
            ("memset_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS"),
            ("host_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS"),
            ("event_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS"),
            ("ext_semas_signal_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS"),
            ("ext_semas_wait_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS"),
            ("kernel_node_attributes", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES"),
            ("handles", "CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES"),
            ("mem_alloc_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS"),
            ("mem_free_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS"),
            ("batch_mem_op_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS"),
            ("extra_topo_info", "CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO"),
            ("conditional_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS"),
        )
        mask = 0
        for option, enumerator in option_to_enumerator:
            # The driver enum is only looked up for options that are switched on.
            if getattr(self, option):
                mask |= getattr(driver.CUgraphDebugDot_flags, enumerator)
        return mask

118  

119  

@dataclass
class GraphCompleteOptions:
    """Settings applied when a built graph is instantiated.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Before each relaunch, automatically free memory that was allocated
        in the graph. (Default to False)
    upload_stream : Stream, optional
        When given, the graph is automatically uploaded on this stream once
        it has been completed. (Default to None)
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms
        that support unified addressing, and not together with
        auto_free_on_launch. (Default to False)
    use_node_priority : bool, optional
        Execute using the per-node priority attributes instead of the priority
        of the stream the graph is launched into. (Default to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False

144  

145  

def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate ``h_graph`` via ``cuGraphInstantiateWithParams`` and wrap the result.

    Parameters
    ----------
    h_graph
        Graph handle forwarded to ``driver.cuGraphInstantiateWithParams``.
    options : GraphCompleteOptions, optional
        Each enabled option ORs the matching ``CUgraphInstantiate_flags`` value
        into the instantiate parameters; ``upload_stream`` additionally sets
        ``hUploadStream``. When omitted, the parameter struct is passed as
        constructed.

    Returns
    -------
    Graph
        The object produced by ``Graph._init`` for the instantiated handle.

    Raises
    ------
    RuntimeError
        If ``result_out`` of the instantiate parameters is anything other than
        ``CUDA_GRAPH_INSTANTIATE_SUCCESS`` after the call.
    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        requested = 0
        if options.auto_free_on_launch:
            requested |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            requested |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            requested |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            requested |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = requested

    exec_handle = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))
    graph = Graph._init(exec_handle)

    # Translate the driver-reported instantiate result into a descriptive error.
    outcome = params.result_out
    results = driver.CUgraphInstantiateResult
    if outcome == results.CUDA_GRAPH_INSTANTIATE_ERROR:
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    if outcome == results.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    if outcome == results.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    if outcome == results.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    # This enumerator is only consulted when the binding version is at least 12.8.
    if (
        cy_binding_version() >= (12, 8, 0)
        and outcome == results.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    if outcome != results.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {outcome}")
    return graph

182  

183  

184class GraphBuilder: 

185 """A graph under construction by stream capture. 

186  

187 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

188 them with a specified dependency tree. It speeds up the workflow by combining the 

189 driver activities associated with CUDA kernel launches and CUDA API calls. 

190  

191 Directly creating a :obj:`~graph.GraphBuilder` is not supported due 

192 to ambiguity. New graph builders should instead be created through a 

193 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object. 

194  

195 """ 

196  

class _MembersNeededForFinalize:
    """Resources of a graph builder that must be released when it goes away.

    The state lives on this helper rather than on the builder so that the
    ``weakref.finalize`` callback registered in ``__init__`` (this object's
    bound ``close``) does not itself reference the builder.
    """

    __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

    def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
        self.stream = stream_obj
        self.is_stream_owner = is_stream_owner
        self.graph = None
        self.conditional_graph = conditional_graph
        self.is_join_required = is_join_required
        # Run close() automatically once the owning builder object is collected.
        weakref.finalize(graph_builder_obj, self.close)

    def close(self):
        """End any still-active capture, close an owned stream and destroy the graph.

        Safe to call more than once: every member is reset after it has been
        released, so later calls find nothing left to do.
        """
        if self.stream:
            if not self.is_join_required:
                capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0]
                if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                    # Note how this condition only occurs for the primary graph builder.
                    # This is because calling cuStreamEndCapture on streams that were split off
                    # of the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                    # Therefore, it is currently a requirement that users join all split graph
                    # builders before a graph builder can be cleanly destroyed.
                    captured = handle_return(driver.cuStreamEndCapture(self.stream.handle))
                    if not self.conditional_graph and not self.graph:
                        # The capture was not directed into a conditional graph, so the
                        # graph handed back here has no other owner; keep it so that it
                        # is destroyed below instead of being dropped.
                        self.graph = captured
            if self.is_stream_owner:
                self.stream.close()
            self.stream = None
        if self.graph:
            handle_return(driver.cuGraphDestroy(self.graph))
            self.graph = None
        self.conditional_graph = None

226  

227 __slots__ = ("__weakref__", "_building_ended", "_mnff") 

228  

229 def __init__(self): 

230 raise NotImplementedError( 

231 "directly creating a Graph object can be ambiguous. Please either " 

232 "call Device.create_graph_builder() or stream.create_graph_builder()" 

233 ) 

234  

235 @classmethod 

236 def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False): 

237 self = cls.__new__(cls) 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

238 self._mnff = GraphBuilder._MembersNeededForFinalize( 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

239 self, stream, is_stream_owner, conditional_graph, is_join_required 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

240 ) 

241  

242 self._building_ended = False 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

243 return self 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

244  

245 @property 

246 def stream(self) -> Stream: 

247 """Returns the stream associated with the graph builder.""" 

248 return self._mnff.stream 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

249  

250 @property 

251 def is_join_required(self) -> bool: 

252 """Returns True if this graph builder must be joined before building is ended.""" 

253 return self._mnff.is_join_required 1ABbcdefghijklmnopqrstua

254  

255 def begin_building(self, mode="relaxed") -> GraphBuilder: 

256 """Begins the building process. 

257  

258 Build `mode` for controlling interaction with other API calls must be one of the following: 

259  

260 - `global` : Prohibit potentially unsafe operations across all streams in the process. 

261 - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread. 

262 - `relaxed` : The local thread is not prohibited from potentially unsafe operations. 

263  

264 Parameters 

265 ---------- 

266 mode : str, optional 

267 Build mode to control the interaction with other API calls that are porentially unsafe. 

268 Default set to use relaxed. 

269  

270 """ 

271 if self._building_ended: 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

272 raise RuntimeError("Cannot resume building after building has ended.") 1V

273 if mode not in ("global", "thread_local", "relaxed"): 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

274 raise ValueError(f"Unsupported build mode: {mode}") 1!

275 if mode == "global": 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

276 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL 18923L45!

277 elif mode == "thread_local": 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz0617QWRXJKSYTZvHENM!Da

278 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL 167WXKYZ!

279 elif mode == "relaxed": 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz01QRJSTvHENM!Da

280 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz01QRJSTvHENM!Da

281 else: 

282 raise ValueError(f"Unsupported build mode: {mode}") 

283  

284 if self._mnff.conditional_graph: 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

285 handle_return( 1bcdefghijklmnopqrstuwxyzva

286 driver.cuStreamBeginCaptureToGraph( 1bcdefghijklmnopqrstuwxyzva

287 self._mnff.stream.handle, 1bcdefghijklmnopqrstuwxyzva

288 self._mnff.conditional_graph, 1bcdefghijklmnopqrstuwxyzva

289 None, # dependencies 

290 None, # dependencyData 

291 0, # numDependencies 

292 capture_mode, 1bcdefghijklmnopqrstuwxyzva

293 ) 

294 ) 

295 else: 

296 handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

297 return self 1FGUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENM!Da

298  

@property
def is_building(self) -> bool:
    """Whether the builder's stream is currently capturing.

    Raises
    ------
    RuntimeError
        If the capture status reports that the build has been invalidated.
    NotImplementedError
        If the driver reports a capture status this code does not know.
    """
    statuses = driver.CUstreamCaptureStatus
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
    capture_status = capture_info[0]

    if capture_status == statuses.CU_STREAM_CAPTURE_STATUS_NONE:
        return False
    if capture_status == statuses.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return True
    if capture_status == statuses.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. Build process must now be ended."
        )
    raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")

313  

def end_building(self) -> GraphBuilder:
    """Stop capturing and keep the captured graph on the builder.

    Returns
    -------
    This same builder.

    Raises
    ------
    RuntimeError
        If the builder is not currently building.
    """
    if not self.is_building:
        raise RuntimeError("Graph builder is not building.")

    members = self._mnff
    if members.conditional_graph:
        # A conditional builder keeps the capture result in its conditional graph slot.
        members.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
    else:
        members.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))

    # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
    # resume the build process after the first call to end_building()
    self._building_ended = True
    return self

327  

328 def complete(self, options: GraphCompleteOptions | None = None) -> "Graph": 

329 """Completes the graph builder and returns the built :obj:`~graph.Graph` object. 

330  

331 Parameters 

332 ---------- 

333 options : :obj:`~graph.GraphCompleteOptions`, optional 

334 Customizable dataclass for the graph builder completion options. 

335  

336 Returns 

337 ------- 

338 graph : :obj:`~graph.Graph` 

339 The newly built graph. 

340  

341 """ 

342 if not self._building_ended: 1FGUCABVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENMD

343 raise RuntimeError("Graph has not finished building.") 1U

344  

345 return _instantiate_graph(self._mnff.graph, options) 1FGUCABVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvHENMD

346  

347 def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None): 

348 """Generates a DOT debug file for the graph builder. 

349  

350 Parameters 

351 ---------- 

352 path : str 

353 File path to use for writting debug DOT output 

354 options : :obj:`~graph.GraphDebugPrintOptions`, optional 

355 Customizable dataclass for the debug print options. 

356  

357 """ 

358 if not self._building_ended: 1a

359 raise RuntimeError("Graph has not finished building.") 

360 flags = options._to_flags() if options else 0 1a

361 handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags)) 1a

362  

363 def split(self, count: int) -> tuple[GraphBuilder, ...]: 

364 """Splits the original graph builder into multiple graph builders. 

365  

366 The new builders inherit work dependencies from the original builder. 

367 The original builder is reused for the split and is returned first in the tuple. 

368  

369 Parameters 

370 ---------- 

371 count : int 

372 The number of graph builders to split the graph builder into. 

373  

374 Returns 

375 ------- 

376 graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...] 

377 A tuple of split graph builders. The first graph builder in the tuple 

378 is always the original graph builder. 

379  

380 """ 

381 if count < 2: 1ABbcdefghijklmnopqrstua

382 raise ValueError(f"Invalid split count: expecting >= 2, got {count}") 1A

383  

384 event = self._mnff.stream.record() 1ABbcdefghijklmnopqrstua

385 result = [self] 1ABbcdefghijklmnopqrstua

386 for i in range(count - 1): 1ABbcdefghijklmnopqrstua

387 stream = self._mnff.stream.device.create_stream() 1ABbcdefghijklmnopqrstua

388 stream.wait(event) 1ABbcdefghijklmnopqrstua

389 result.append( 1ABbcdefghijklmnopqrstua

390 GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) 1ABbcdefghijklmnopqrstua

391 ) 

392 event.close() 1ABbcdefghijklmnopqrstua

393 return tuple(result) 1ABbcdefghijklmnopqrstua

394  

395 @staticmethod 

396 def join(*graph_builders) -> GraphBuilder: 

397 """Joins multiple graph builders into a single graph builder. 

398  

399 The returned builder inherits work dependencies from the provided builders. 

400  

401 Parameters 

402 ---------- 

403 *graph_builders : :obj:`~graph.GraphBuilder` 

404 The graph builders to join. 

405  

406 Returns 

407 ------- 

408 graph_builder : :obj:`~graph.GraphBuilder` 

409 The newly joined graph builder. 

410  

411 """ 

412 if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): 1ABbcdefghijklmnopqrstua

413 raise TypeError("All arguments must be GraphBuilder instances") 

414 if len(graph_builders) < 2: 1ABbcdefghijklmnopqrstua

415 raise ValueError("Must join with at least two graph builders") 1A

416  

417 # Discover the root builder others should join 

418 root_idx = 0 1ABbcdefghijklmnopqrstua

419 for i, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua

420 if not builder.is_join_required: 1ABbcdefghijklmnopqrstua

421 root_idx = i 1ABbcdefghijklmnopqrstua

422 break 1ABbcdefghijklmnopqrstua

423  

424 # Join all onto the root builder 

425 root_bdr = graph_builders[root_idx] 1ABbcdefghijklmnopqrstua

426 for idx, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua

427 if idx == root_idx: 1ABbcdefghijklmnopqrstua

428 continue 1ABbcdefghijklmnopqrstua

429 root_bdr.stream.wait(builder.stream) 1ABbcdefghijklmnopqrstua

430 builder.close() 1ABbcdefghijklmnopqrstua

431  

432 return root_bdr 1ABbcdefghijklmnopqrstua

433  

434 def __cuda_stream__(self) -> tuple[int, int]: 

435 """Return an instance of a __cuda_stream__ protocol.""" 

436 return self.stream.__cuda_stream__() 

437  

438 def _get_conditional_context(self) -> driver.CUcontext: 

439 return self._mnff.stream.context.handle 1bcdefghijklmnopqrstuwxyzva

440  

def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle:
    """Create a conditional handle on the graph that is currently being built.

    Parameters
    ----------
    default_value : int, optional
        The default value to assign to the conditional handle. When omitted,
        no default-assignment flag is passed and the value 0 is used.

    Returns
    -------
    handle : driver.CUgraphConditionalHandle
        The newly created conditional handle.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3, or if the
        builder's stream is not actively capturing.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional handles")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional handles")

    if default_value is None:
        default_value = 0
        flags = 0
    else:
        flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT

    # Only the capture status and the graph being captured are needed here.
    status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot create a conditional handle when graph is not being built")

    create_result = driver.cuGraphConditionalHandleCreate(
        graph, self._get_conditional_context(), default_value, flags
    )
    return handle_return(create_result)

472  

def _cond_with_params(self, node_params) -> tuple:
    """Add a conditional node built from ``node_params`` and return its body builders.

    The node is added to the graph under capture using the stream's current
    capture dependencies, the stream's capture dependencies are then set to
    that single new node, and one builder is created per entry of
    ``node_params.conditional.phGraph_out``.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    # Get current capture info to ensure we're in a valid state
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
    status, _, graph, *deps_info, num_dependencies = capture_info
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    # Add the conditional node to the graph
    cond_node = handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))

    # The new node is the only dependency; any further dependency arguments become None.
    deps_info_update = [[cond_node]]
    deps_info_update.extend([None] * (len(deps_info) - 1))

    # Update the stream's capture dependencies
    update_result = driver.cuStreamUpdateCaptureDependencies(
        self._mnff.stream.handle,
        *deps_info_update,  # dependencies, edgeData
        1,  # numDependencies
        driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
    )
    handle_return(update_result)

    # Create new graph builders for each condition
    branch_builders = []
    for branch in range(node_params.conditional.size):
        branch_builders.append(
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[branch],
                is_join_required=False,
            )
        )
    return tuple(branch_builders)

508  

def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
    """Add an if conditional and return the graph builder for its branch.

    The branch only executes when the conditional handle evaluates to true
    at runtime.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the if conditional.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created conditional graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if")

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond = node_params.conditional
    cond.handle = handle
    cond.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    cond.size = 1  # a single body graph: the "if" branch
    cond.ctx = self._get_conditional_context()

    (branch,) = self._cond_with_params(node_params)[:1]
    return branch

539  

def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]:
    """Add an if-else conditional and return the graph builders for both branches.

    At runtime the if branch executes when the conditional handle evaluates
    to true; otherwise the else branch executes.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the if-else conditional.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        A tuple of two new graph builders, one for the if branch and one for the else branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if-else")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if-else")

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond = node_params.conditional
    cond.handle = handle
    cond.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    cond.size = 2  # two body graphs: the "if" branch and the "else" branch
    cond.ctx = self._get_conditional_context()

    branches = self._cond_with_params(node_params)
    return branches

570  

def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]:
    """Add a switch conditional and return the graph builders for all of its cases.

    At runtime the branch whose case index matches the conditional handle
    executes. If no match is found, no branch will be executed.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the switch conditional.
    count : int
        The number of cases to add to the switch conditional.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of new graph builders, one for each branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional switch")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional switch")

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond = node_params.conditional
    cond.handle = handle
    cond.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    cond.size = count  # one body graph per case
    cond.ctx = self._get_conditional_context()

    case_builders = self._cond_with_params(node_params)
    return case_builders

604  

def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
    """Add a while-loop conditional node and return a builder for its body.

    At runtime the loop body executes repeatedly until the conditional
    handle evaluates to false.

    The new builder inherits the work dependencies of this builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The conditional handle that drives the while loop.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created while loop graph builder.

    Raises
    ------
    RuntimeError
        If the driver version or the binding version is older than 12.3.0.

    """
    # Check the driver first, then the bindings; both must be at least 12.3.0.
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        found = ".".join(map(str, driver_ver))
        raise RuntimeError(f"Driver version {found} does not support conditional while loop")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        found = ".".join(map(str, binding_ver))
        raise RuntimeError(f"Binding version {found} does not support conditional while loop")

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = handle
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    # A while loop has a single body.
    params.conditional.size = 1
    params.conditional.ctx = self._get_conditional_context()
    branches = self._cond_with_params(params)
    # Only one body was requested, so return its builder directly rather than the container.
    return branches[0]

635  

def close(self):
    """Destroy the graph builder.

    The associated stream is closed if this builder owns it; a borrowed
    stream object has its reference released instead.

    """
    # All teardown lives in the finalizer-state object.
    finalizer_state = self._mnff
    finalizer_state.close()

644  

def add_child(self, child_graph: GraphBuilder):
    """Embed a finished :obj:`~graph.GraphBuilder` into this builder as a child graph node.

    The child graph is added as a child node of the graph currently being
    captured by this builder, and the capture dependencies of the stream
    are then set to that new node.

    Parameters
    ----------
    child_graph : :obj:`~graph.GraphBuilder`
        The child graph builder. Must have finished building.

    Raises
    ------
    ValueError
        If the child graph has not finished building, or if this (parent)
        builder is not currently building.
    """
    if not child_graph._building_ended:
        raise ValueError("Child graph has not finished building.")

    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    capture_stream = self._mnff.stream.handle
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(capture_stream))
    # Third entry is the graph under capture, last entry is the dependency
    # count; whatever sits in between is the dependency information.
    _, _, graph_out, *deps_info_out, num_dependencies_out = capture_info

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    deps_info_trimmed = deps_info_out[:num_dependencies_out]
    child_node = handle_return(
        driver.cuGraphAddChildGraphNode(
            graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph
        )
    )

    # The new child node becomes the only dependency; any remaining
    # dependency-info slots are passed as None.
    deps_info_update = [[child_node]] + [None] * (len(deps_info_out) - 1)
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            capture_stream,
            *deps_info_update,  # dependencies, edgeData
            1,
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

686  

def callback(self, fn, *, user_data=None):
    """Insert a host callback into the graph being captured.

    The callback runs on the host CPU when graph execution reaches this
    point. Two modes are supported:

    - **Python callable**: Pass any callable. The GIL is acquired
      automatically. The callable must take no arguments; use closures
      or ``functools.partial`` to bind state.
    - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance.
      The function receives a single ``void*`` argument (the
      ``user_data``). The caller must keep the ctypes wrapper alive
      for the lifetime of the graph.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The callback function.
    user_data : int or bytes-like, optional
        Only for ctypes function pointers. If ``int``, passed as a raw
        pointer (caller manages lifetime). If bytes-like, the data is
        copied and its lifetime is tied to the graph.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    cdef Stream owner_stream = <Stream>self._mnff.stream
    cdef cydriver.CUstream cu_stream = as_cu(owner_stream._h_stream)
    cdef cydriver.CUstreamCaptureStatus status
    cdef cydriver.CUgraph cu_graph = NULL
    cdef cydriver.CUhostFn host_fn
    cdef void* host_arg = NULL

    # Query the capture status and the graph under capture. The call takes
    # one more argument when built against CUDA 13 or newer.
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                cu_stream, &status, NULL, &cu_graph, NULL, NULL, NULL))
        ELSE:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                cu_stream, &status, NULL, &cu_graph, NULL, NULL))

    if status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    # Resolve fn/user_data into the C function pointer and argument to launch.
    _attach_host_callback_to_graph(cu_graph, fn, user_data, &host_fn, &host_arg)

    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(cu_stream, host_fn, host_arg))

737  

738  

739class Graph: 

740 """An executable graph. 

741  

742 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

743 them with a specified dependency tree. It speeds up the workflow by combining the 

744 driver activities associated with CUDA kernel launches and CUDA API calls. 

745  

746 Graphs must be built using a :obj:`~graph.GraphBuilder` object. 

747  

748 """ 

749  

class _MembersNeededForFinalize:
    """Hold the executable-graph handle so it can be destroyed at finalization.

    The finalizer callback is bound to this helper rather than to the owning
    object, so the callback itself does not reference the owner.
    """

    # A tuple rather than a bare string, matching the tuple form used for
    # ``__slots__`` on the enclosing class.
    __slots__ = ("graph",)

    def __init__(self, graph_obj, graph):
        """Store ``graph`` and register ``close`` as the finalizer of ``graph_obj``.

        Parameters
        ----------
        graph_obj : object
            The owning object; must support weak references.
        graph
            The executable graph handle to destroy on close.
        """
        self.graph = graph
        weakref.finalize(graph_obj, self.close)

    def close(self):
        """Destroy the held handle, if any, and clear it.

        Calling this again after the handle has been cleared is a no-op.
        """
        if self.graph:
            handle_return(driver.cuGraphExecDestroy(self.graph))
            self.graph = None

761  

762 __slots__ = ("__weakref__", "_mnff") 

763  

def __init__(self):
    # Direct construction is always rejected; the ``_init`` classmethod
    # builds instances through ``__new__`` instead.
    message = "directly constructing a Graph instance is not supported"
    raise RuntimeError(message)

766  

@classmethod
def _init(cls, graph):
    # Internal constructor: bypass __init__ (which always raises) and attach
    # the finalizer state that holds the given executable graph handle.
    instance = cls.__new__(cls)
    instance._mnff = Graph._MembersNeededForFinalize(instance, graph)
    return instance

772  

def close(self):
    """Destroy the graph."""
    # Teardown is delegated to the finalizer-state object.
    finalizer_state = self._mnff
    finalizer_state.close()

776  

@property
def handle(self) -> driver.CUgraphExec:
    """The underlying ``CUgraphExec`` object.

    .. caution::

        The returned handle is a Python object. Call ``int()`` on it to
        obtain the memory address of the underlying C handle.

    """
    finalizer_state = self._mnff
    return finalizer_state.graph

788  

def update(self, source: "GraphBuilder | GraphDefinition") -> None:
    """Update this executable graph in place from a new graph definition.

    The topology of the provided source must be identical to this graph.

    Parameters
    ----------
    source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
        The graph definition to update from. A GraphBuilder must have
        finished building.

    Raises
    ------
    ValueError
        If ``source`` is a GraphBuilder that has not finished building.
    TypeError
        If ``source`` is neither a GraphBuilder nor a GraphDefinition.
    CUDAError
        If the driver reports an update failure; the message includes the
        reason reported by the driver. Any other driver error is raised
        through ``HANDLE_RETURN``.

    """
    # Imported inside the method rather than at module level.
    from cuda.core.graph import GraphDefinition

    cdef cydriver.CUgraph cu_graph
    cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)

    if isinstance(source, GraphBuilder):
        if not source._building_ended:
            raise ValueError("Graph has not finished building.")
        cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph)
    elif isinstance(source, GraphDefinition):
        cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
    else:
        raise TypeError(
            f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}")

    cdef cydriver.CUgraphExecUpdateResultInfo result_info
    cdef cydriver.CUresult err
    with nogil:
        err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
    if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
        # Translate the detailed result code into a readable message.
        reason = driver.CUgraphExecUpdateResult(result_info.result)
        # __doc__ can be None (for example when docstrings are unavailable);
        # fall back to the enum name alone instead of failing with an
        # AttributeError that would hide the real update error.
        detail = (reason.__doc__ or "").strip()
        if detail:
            msg = f"Graph update failed: {detail} ({reason.name})"
        else:
            msg = f"Graph update failed: {reason.name}"
        raise CUDAError(msg)
    HANDLE_RETURN(err)

825  

def upload(self, stream: Stream):
    """Upload the graph in a stream.

    Parameters
    ----------
    stream : :obj:`~_stream.Stream`
        The stream in which to upload the graph

    """
    exec_graph = self._mnff.graph
    result = driver.cuGraphUpload(exec_graph, stream.handle)
    handle_return(result)

836  

def launch(self, stream: Stream):
    """Launch the graph in a stream.

    Parameters
    ----------
    stream : :obj:`~_stream.Stream`
        The stream in which to launch the graph

    """
    exec_graph = self._mnff.graph
    result = driver.cuGraphLaunch(exec_graph, stream.handle)
    handle_return(result)