Coverage for cuda / core / graph / _graph_builder.pyx: 88.28%

384 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-22 01:37 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5import weakref 

6from dataclasses import dataclass 

7  

8from libc.stdint cimport intptr_t 

9  

10from cuda.bindings cimport cydriver 

11  

12from cuda.core.graph._graph_definition cimport GraphCondition 

13from cuda.core.graph._utils cimport _attach_host_callback_to_graph 

14from cuda.core._resource_handles cimport as_cu 

15from cuda.core._stream cimport Stream 

16from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

17from cuda.core._utils.version cimport cy_binding_version, cy_driver_version 

18  

19from cuda.core._utils.cuda_utils import ( 

20 CUDAError, 

21 driver, 

22 handle_return, 

23) 

24  

25__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions'] 

26  

27  

@dataclass
class GraphDebugPrintOptions:
    """Switches that select what ``debug_dot_print()`` writes out.

    Every attribute is a boolean defaulting to ``False``. Enabling one ORs
    the corresponding ``CU_GRAPH_DEBUG_DOT_FLAGS_*`` driver flag into the
    value produced by :meth:`_to_flags`.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag is enabled (Default to False)
    runtime_types : bool
        Use CUDA Runtime structures for output (Default to False)
    kernel_node_params : bool
        Adds kernel parameter values to output (Default to False)
    memcpy_node_params : bool
        Adds memcpy parameter values to output (Default to False)
    memset_node_params : bool
        Adds memset parameter values to output (Default to False)
    host_node_params : bool
        Adds host parameter values to output (Default to False)
    event_node_params : bool
        Adds event parameter values to output (Default to False)
    ext_semas_signal_node_params : bool
        Adds external semaphore signal parameter values to output (Default to False)
    ext_semas_wait_node_params : bool
        Adds external semaphore wait parameter values to output (Default to False)
    kernel_node_attributes : bool
        Adds kernel node attributes to output (Default to False)
    handles : bool
        Adds node handles and every kernel function handle to output (Default to False)
    mem_alloc_node_params : bool
        Adds memory alloc parameter values to output (Default to False)
    mem_free_node_params : bool
        Adds memory free parameter values to output (Default to False)
    batch_mem_op_node_params : bool
        Adds batch mem op parameter values to output (Default to False)
    extra_topo_info : bool
        Adds edge numbering information (Default to False)
    conditional_node_params : bool
        Adds conditional node parameter values to output (Default to False)

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Fold the enabled options into a driver flag bitmask (internal use).

        Returns
        -------
        int
            Bitwise OR of the ``CUgraphDebugDot_flags`` members whose option
            is truthy; ``0`` when no option is enabled.
        """
        # (option attribute, CUgraphDebugDot_flags member name) pairs.
        option_to_member = (
            ("verbose", "CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE"),
            ("runtime_types", "CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES"),
            ("kernel_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS"),
            ("memcpy_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS"),
            ("memset_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS"),
            ("host_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS"),
            ("event_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS"),
            ("ext_semas_signal_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS"),
            ("ext_semas_wait_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS"),
            ("kernel_node_attributes", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES"),
            ("handles", "CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES"),
            ("mem_alloc_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS"),
            ("mem_free_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS"),
            ("batch_mem_op_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS"),
            ("extra_topo_info", "CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO"),
            ("conditional_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS"),
        )

        mask = 0
        for option, member in option_to_member:
            if getattr(self, option):
                # Look the enum member up only when its option is enabled, so
                # members that were not requested are never touched.
                mask |= getattr(driver.CUgraphDebugDot_flags, member)
        return mask

122  

123  

@dataclass
class GraphCompleteOptions:
    """Settings applied when a built graph is instantiated.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        When set, memory allocated in the graph is freed automatically before
        the graph is relaunched. (Default to False)
    upload_stream : Stream, optional
        When given, the graph is uploaded on this stream automatically after
        completion. (Default to None)
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms
        that support unified addressing, and cannot be combined with
        auto_free_on_launch. (Default to False)
    use_node_priority : bool, optional
        Run the graph with the per-node priority attributes instead of the
        priority of the stream it is launched into. (Default to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False

148  

149  

def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate ``h_graph`` through ``cuGraphInstantiateWithParams``.

    Parameters
    ----------
    h_graph
        Graph handle handed to the driver for instantiation.
    options : GraphCompleteOptions, optional
        Instantiation options. When falsy, the instantiate params are passed
        to the driver as constructed, without flags being set here.

    Returns
    -------
    Graph
        Object built by ``Graph._init`` from the instantiated handle.

    Raises
    ------
    RuntimeError
        If the ``result_out`` reported by the driver is anything other than
        ``CUDA_GRAPH_INSTANTIATE_SUCCESS``.
    """
    inst_params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()

    if options:
        inst_flags = 0
        if options.auto_free_on_launch:
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            # Uploading needs both the flag and the stream to upload on.
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            inst_params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            inst_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        inst_params.flags = inst_flags

    exec_graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, inst_params)))

    # The driver additionally reports a detailed outcome in ``result_out``;
    # translate every non-success outcome into a RuntimeError.
    outcome = inst_params.result_out
    results = driver.CUgraphInstantiateResult
    if outcome == results.CUDA_GRAPH_INSTANTIATE_ERROR:
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    if outcome == results.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    if outcome == results.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    if outcome == results.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    # This enum member is only consulted when the bindings are at least 12.8.
    if (
        cy_binding_version() >= (12, 8, 0)
        and outcome == results.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    if outcome != results.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {outcome}")
    return exec_graph

186  

187  

188class GraphBuilder: 

189 """A graph under construction by stream capture. 

190  

191 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

192 them with a specified dependency tree. It speeds up the workflow by combining the 

193 driver activities associated with CUDA kernel launches and CUDA API calls. 

194  

195 Directly creating a :obj:`~graph.GraphBuilder` is not supported due 

196 to ambiguity. New graph builders should instead be created through a 

197 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object. 

198  

199 """ 

200  

201 class _MembersNeededForFinalize: 

202 __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream") 

203  

204 def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required): 

205 self.stream = stream_obj 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

206 self.is_stream_owner = is_stream_owner 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

207 self.graph = None 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

208 self.conditional_graph = conditional_graph 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

209 self.is_join_required = is_join_required 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

210 weakref.finalize(graph_builder_obj, self.close) 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

211  

212 def close(self): 

213 if self.stream: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

214 if not self.is_join_required: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

215 capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0] 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

216 if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

217 # Note how this condition only occures for the primary graph builder 

218 # This is because calling cuStreamEndCapture streams that were split off of the primary 

219 # would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED. 

220 # Therefore, it is currently a requirement that users join all split graph builders 

221 # before a graph builder can be clearly destroyed. 

222 handle_return(driver.cuStreamEndCapture(self.stream.handle)) 

223 if self.is_stream_owner: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

224 self.stream.close() 1UCA#BOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

225 self.stream = None 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

226 if self.graph: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

227 handle_return(driver.cuGraphDestroy(self.graph)) 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

228 self.graph = None 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

229 self.conditional_graph = None 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

230  

231 __slots__ = ("__weakref__", "_building_ended", "_mnff") 

232  

233 def __init__(self): 

234 raise NotImplementedError( 

235 "directly creating a Graph object can be ambiguous. Please either " 

236 "call Device.create_graph_builder() or stream.create_graph_builder()" 

237 ) 

238  

239 @classmethod 

240 def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False): 

241 self = cls.__new__(cls) 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

242 self._mnff = GraphBuilder._MembersNeededForFinalize( 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

243 self, stream, is_stream_owner, conditional_graph, is_join_required 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

244 ) 

245  

246 self._building_ended = False 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

247 return self 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

248  

249 @property 

250 def stream(self) -> Stream: 

251 """Returns the stream associated with the graph builder.""" 

252 return self._mnff.stream 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

253  

254 @property 

255 def is_join_required(self) -> bool: 

256 """Returns True if this graph builder must be joined before building is ended.""" 

257 return self._mnff.is_join_required 1ABbcdefghijklmnopqrstua

258  

259 def begin_building(self, mode="relaxed") -> GraphBuilder: 

260 """Begins the building process. 

261  

262 Build `mode` for controlling interaction with other API calls must be one of the following: 

263  

264 - `global` : Prohibit potentially unsafe operations across all streams in the process. 

265 - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread. 

266 - `relaxed` : The local thread is not prohibited from potentially unsafe operations. 

267  

268 Parameters 

269 ---------- 

270 mode : str, optional 

271 Build mode to control the interaction with other API calls that are porentially unsafe. 

272 Default set to use relaxed. 

273  

274 """ 

275 if self._building_ended: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

276 raise RuntimeError("Cannot resume building after building has ended.") 1V

277 if mode not in ("global", "thread_local", "relaxed"): 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

278 raise ValueError(f"Unsupported build mode: {mode}") 1!

279 if mode == "global": 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

280 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL 18923L45!

281 elif mode == "thread_local": 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz0617QWRXJKSYTZvIENM!Da

282 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL 167WXKYZ!

283 elif mode == "relaxed": 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz01QRJSTvIENM!Da

284 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz01QRJSTvIENM!Da

285 else: 

286 raise ValueError(f"Unsupported build mode: {mode}") 

287  

288 if self._mnff.conditional_graph: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

289 handle_return( 1bcdefghijklmnopqrstuwxyzva

290 driver.cuStreamBeginCaptureToGraph( 1bcdefghijklmnopqrstuwxyzva

291 self._mnff.stream.handle, 1bcdefghijklmnopqrstuwxyzva

292 self._mnff.conditional_graph, 1bcdefghijklmnopqrstuwxyzva

293 None, # dependencies 

294 None, # dependencyData 

295 0, # numDependencies 

296 capture_mode, 1bcdefghijklmnopqrstuwxyzva

297 ) 

298 ) 

299 else: 

300 handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

301 return self 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

302  

303 @property 

304 def is_building(self) -> bool: 

305 """Returns True if the graph builder is currently building.""" 

306 capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0] 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

307 if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

308 return False 1#

309 elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

310 return True 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

311 elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED: 

312 raise RuntimeError( 

313 "Build process encountered an error and has been invalidated. Build process must now be ended." 

314 ) 

315 else: 

316 raise NotImplementedError(f"Unsupported capture status type received: {capture_status}") 

317  

318 def end_building(self) -> GraphBuilder: 

319 """Ends the building process.""" 

320 if not self.is_building: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

321 raise RuntimeError("Graph builder is not building.") 

322 if self._mnff.conditional_graph: 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

323 self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) 1bcdefghijklmnopqrstuwxyzva

324 else: 

325 self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

326  

327 # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to 

328 # resume the build process after the first call to end_building() 

329 self._building_ended = True 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

330 return self 1GHUCA#BVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENM!Da

331  

332 def complete(self, options: GraphCompleteOptions | None = None) -> "Graph": 

333 """Completes the graph builder and returns the built :obj:`~graph.Graph` object. 

334  

335 Parameters 

336 ---------- 

337 options : :obj:`~graph.GraphCompleteOptions`, optional 

338 Customizable dataclass for the graph builder completion options. 

339  

340 Returns 

341 ------- 

342 graph : :obj:`~graph.Graph` 

343 The newly built graph. 

344  

345 """ 

346 if not self._building_ended: 1GHUCABVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENMD

347 raise RuntimeError("Graph has not finished building.") 1U

348  

349 return _instantiate_graph(self._mnff.graph, options) 1GHUCABVOPbcdefghijklmnopqrstuwxyz8069172QW3RXLJK4SY5TZvIENMD

350  

351 def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None): 

352 """Generates a DOT debug file for the graph builder. 

353  

354 Parameters 

355 ---------- 

356 path : str 

357 File path to use for writting debug DOT output 

358 options : :obj:`~graph.GraphDebugPrintOptions`, optional 

359 Customizable dataclass for the debug print options. 

360  

361 """ 

362 if not self._building_ended: 1a

363 raise RuntimeError("Graph has not finished building.") 

364 flags = options._to_flags() if options else 0 1a

365 handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags)) 1a

366  

367 def split(self, count: int) -> tuple[GraphBuilder, ...]: 

368 """Splits the original graph builder into multiple graph builders. 

369  

370 The new builders inherit work dependencies from the original builder. 

371 The original builder is reused for the split and is returned first in the tuple. 

372  

373 Parameters 

374 ---------- 

375 count : int 

376 The number of graph builders to split the graph builder into. 

377  

378 Returns 

379 ------- 

380 graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...] 

381 A tuple of split graph builders. The first graph builder in the tuple 

382 is always the original graph builder. 

383  

384 """ 

385 if count < 2: 1ABbcdefghijklmnopqrstua

386 raise ValueError(f"Invalid split count: expecting >= 2, got {count}") 1A

387  

388 event = self._mnff.stream.record() 1ABbcdefghijklmnopqrstua

389 result = [self] 1ABbcdefghijklmnopqrstua

390 for i in range(count - 1): 1ABbcdefghijklmnopqrstua

391 stream = self._mnff.stream.device.create_stream() 1ABbcdefghijklmnopqrstua

392 stream.wait(event) 1ABbcdefghijklmnopqrstua

393 result.append( 1ABbcdefghijklmnopqrstua

394 GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) 1ABbcdefghijklmnopqrstua

395 ) 

396 event.close() 1ABbcdefghijklmnopqrstua

397 return tuple(result) 1ABbcdefghijklmnopqrstua

398  

399 @staticmethod 

400 def join(*graph_builders) -> GraphBuilder: 

401 """Joins multiple graph builders into a single graph builder. 

402  

403 The returned builder inherits work dependencies from the provided builders. 

404  

405 Parameters 

406 ---------- 

407 *graph_builders : :obj:`~graph.GraphBuilder` 

408 The graph builders to join. 

409  

410 Returns 

411 ------- 

412 graph_builder : :obj:`~graph.GraphBuilder` 

413 The newly joined graph builder. 

414  

415 """ 

416 if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): 1ABbcdefghijklmnopqrstua

417 raise TypeError("All arguments must be GraphBuilder instances") 

418 if len(graph_builders) < 2: 1ABbcdefghijklmnopqrstua

419 raise ValueError("Must join with at least two graph builders") 1A

420  

421 # Discover the root builder others should join 

422 root_idx = 0 1ABbcdefghijklmnopqrstua

423 for i, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua

424 if not builder.is_join_required: 1ABbcdefghijklmnopqrstua

425 root_idx = i 1ABbcdefghijklmnopqrstua

426 break 1ABbcdefghijklmnopqrstua

427  

428 # Join all onto the root builder 

429 root_bdr = graph_builders[root_idx] 1ABbcdefghijklmnopqrstua

430 for idx, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua

431 if idx == root_idx: 1ABbcdefghijklmnopqrstua

432 continue 1ABbcdefghijklmnopqrstua

433 root_bdr.stream.wait(builder.stream) 1ABbcdefghijklmnopqrstua

434 builder.close() 1ABbcdefghijklmnopqrstua

435  

436 return root_bdr 1ABbcdefghijklmnopqrstua

437  

438 def __cuda_stream__(self) -> tuple[int, int]: 

439 """Return an instance of a __cuda_stream__ protocol.""" 

440 return self.stream.__cuda_stream__() 

441  

442 def _get_conditional_context(self) -> driver.CUcontext: 

443 return self._mnff.stream.context.handle 1bcdefghijklmnopqrstuwxyzva

444  

445 def create_condition(self, default_value=None) -> GraphCondition: 

446 """Create a condition variable for use with conditional nodes. 

447  

448 The returned :class:`GraphCondition` object is passed to conditional-node 

449 builder methods (:meth:`if_then`, :meth:`if_else`, :meth:`while_loop`, 

450 :meth:`switch`). Its value is controlled at runtime by device code via 

451 ``cudaGraphSetConditional``. 

452  

453 Parameters 

454 ---------- 

455 default_value : int, optional 

456 The default value to assign to the condition. If None, no 

457 default is assigned. 

458  

459 Returns 

460 ------- 

461 GraphCondition 

462 A condition variable for controlling conditional execution. 

463 """ 

464 if cy_driver_version() < (12, 3, 0): 1bcdefghijklmnopqrstuwxyzva

465 raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional handles") 

466 if cy_binding_version() < (12, 3, 0): 1bcdefghijklmnopqrstuwxyzva

467 raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional handles") 

468 if default_value is not None: 1bcdefghijklmnopqrstuwxyzva

469 flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT 1wxyzv

470 else: 

471 default_value = 0 1bcdefghijklmnopqrstua

472 flags = 0 1bcdefghijklmnopqrstua

473  

474 status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)) 1bcdefghijklmnopqrstuwxyzva

475 if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1bcdefghijklmnopqrstuwxyzva

476 raise RuntimeError("Cannot create a condition when graph is not being built") 

477  

478 raw_handle = handle_return( 1bcdefghijklmnopqrstuwxyzva

479 driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags) 1bcdefghijklmnopqrstuwxyzva

480 ) 

481 return GraphCondition._from_handle(<cydriver.CUgraphConditionalHandle><intptr_t>int(raw_handle)) 1bcdefghijklmnopqrstuwxyzva

482  

483 def _cond_with_params(self, node_params) -> tuple: 

484 # Get current capture info to ensure we're in a valid state 

485 status, _, graph, *deps_info, num_dependencies = handle_return( 1bcdefghijklmnopqrstuwxyzva

486 driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) 1bcdefghijklmnopqrstuwxyzva

487 ) 

488 if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1bcdefghijklmnopqrstuwxyzva

489 raise RuntimeError("Cannot add conditional node when not actively capturing") 

490  

491 # Add the conditional node to the graph 

492 deps_info_update = [ 1bcdefghijklmnopqrstuwxyzva

493 [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] 1bcdefghijklmnopqrstuwxyzva

494 ] + [None] * (len(deps_info) - 1) 1bcdefghijklmnopqrstuwxyzva

495  

496 # Update the stream's capture dependencies 

497 handle_return( 1bcdefghijklmnopqrstuwxyzva

498 driver.cuStreamUpdateCaptureDependencies( 1bcdefghijklmnopqrstuwxyzva

499 self._mnff.stream.handle, 1bcdefghijklmnopqrstuwxyzva

500 *deps_info_update, # dependencies, edgeData 1bcdefghijklmnopqrstuwxyzva

501 1, # numDependencies 

502 driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, 1bcdefghijklmnopqrstuwxyzva

503 ) 

504 ) 

505  

506 # Create new graph builders for each condition 

507 return tuple( 1bcdefghijklmnopqrstuwxyzva

508 [ 1bcdefghijklmnopqrstuwxyzva

509 GraphBuilder._init( 1bcdefghijklmnopqrstuwxyzva

510 stream=self._mnff.stream.device.create_stream(), 1bcdefghijklmnopqrstuwxyzva

511 is_stream_owner=True, 

512 conditional_graph=node_params.conditional.phGraph_out[i], 1bcdefghijklmnopqrstuwxyzva

513 is_join_required=False, 1bcdefghijklmnopqrstuwxyzva

514 ) 

515 for i in range(node_params.conditional.size) 1bcdefghijklmnopqrstuwxyzva

516 ] 

517 ) 

518  

519 def if_then(self, condition: GraphCondition) -> GraphBuilder: 

520 """Adds an if condition branch and returns a new graph builder for it. 

521  

522 The resulting if graph will only execute the branch if the 

523 condition evaluates to true at runtime. 

524  

525 The new builder inherits work dependencies from the original builder. 

526  

527 Parameters 

528 ---------- 

529 condition : :class:`~graph.GraphCondition` 

530 The condition variable from :meth:`create_condition` controlling 

531 whether the branch executes. 

532  

533 Returns 

534 ------- 

535 graph_builder : :obj:`~graph.GraphBuilder` 

536 The newly created conditional graph builder. 

537  

538 """ 

539 if cy_driver_version() < (12, 3, 0): 1bcdefghia

540 raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if") 

541 if cy_binding_version() < (12, 3, 0): 1bcdefghia

542 raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if") 

543 if not isinstance(condition, GraphCondition): 1bcdefghia

544 raise TypeError( 

545 f"condition must be a GraphCondition object (from " 

546 f"GraphBuilder.create_condition()), got {type(condition).__name__}") 

547 node_params = driver.CUgraphNodeParams() 1bcdefghia

548 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL 1bcdefghia

549 node_params.conditional.handle = condition.handle 1bcdefghia

550 node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF 1bcdefghia

551 node_params.conditional.size = 1 1bcdefghia

552 node_params.conditional.ctx = self._get_conditional_context() 1bcdefghia

553 return self._cond_with_params(node_params)[0] 1bcdefghia

554  

def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]:
    """Add a two-way conditional and return one builder per branch.

    At runtime the first branch runs when the condition evaluates to
    true; otherwise the second (else) branch runs.

    Both returned builders inherit the work dependencies of this builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable obtained from :meth:`create_condition` that
        selects the branch to execute.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        The builder for the if branch followed by the builder for the
        else branch.

    Raises
    ------
    RuntimeError
        If the driver or the binding is older than 12.8.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    # Version gates come first so that an unsupported setup is reported
    # before any argument validation.
    if cy_driver_version() < (12, 8, 0):
        raise RuntimeError(
            f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if-else"
        )
    if cy_binding_version() < (12, 8, 0):
        raise RuntimeError(
            f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if-else"
        )
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = condition.handle
    # An IF conditional with two body graphs is the if/else form.
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    params.conditional.size = 2
    params.conditional.ctx = self._get_conditional_context()

    branch_builders = self._cond_with_params(params)
    return branch_builders

590  

def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]:
    """Add a switch conditional and return one builder per case.

    At runtime the case whose index equals the value of the condition is
    executed. When no case index matches, none of the branches run.

    All returned builders inherit the work dependencies of this builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable obtained from :meth:`create_condition` whose
        value picks the case to execute.
    count : int
        Number of cases the switch conditional should have.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        One new graph builder for every case.

    Raises
    ------
    RuntimeError
        If the driver or the binding is older than 12.8.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    # Version gates come first so that an unsupported setup is reported
    # before any argument validation.
    if cy_driver_version() < (12, 8, 0):
        raise RuntimeError(
            f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional switch"
        )
    if cy_binding_version() < (12, 8, 0):
        raise RuntimeError(
            f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional switch"
        )
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = condition.handle
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    # The caller-supplied case count is forwarded unchanged as the node size.
    params.conditional.size = count
    params.conditional.ctx = self._get_conditional_context()

    case_builders = self._cond_with_params(params)
    return case_builders

629  

def while_loop(self, condition: GraphCondition) -> GraphBuilder:
    """Add a while loop and return the builder for its body.

    At runtime the loop body is executed over and over until the
    condition evaluates to false.

    The returned builder inherits the work dependencies of this builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable obtained from :meth:`create_condition` that
        decides whether the loop keeps running.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The graph builder for the newly created while loop.

    Raises
    ------
    RuntimeError
        If the driver or the binding is older than 12.3.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    # Version gates come first so that an unsupported setup is reported
    # before any argument validation.
    if cy_driver_version() < (12, 3, 0):
        raise RuntimeError(
            f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional while loop"
        )
    if cy_binding_version() < (12, 3, 0):
        raise RuntimeError(
            f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional while loop"
        )
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = condition.handle
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    params.conditional.size = 1
    params.conditional.ctx = self._get_conditional_context()

    # Only one body graph was requested, so hand back its builder directly.
    body_builders = self._cond_with_params(params)
    return body_builders[0]

665  

def close(self):
    """Destroy the graph builder.

    The associated stream is closed when this builder owns it. A
    borrowed stream object is not closed; its reference is released
    instead.

    """
    # All cleanup lives on the finalizer-state object.
    finalizer_state = self._mnff
    finalizer_state.close()

674  

def embed(self, child: GraphBuilder):
    """Insert an already-built :obj:`~graph.GraphBuilder` as a child graph node.

    Parameters
    ----------
    child : :obj:`~graph.GraphBuilder`
        The graph builder to embed. It must have finished building.

    Raises
    ------
    ValueError
        If ``child`` has not finished building, or if this (parent)
        builder is not currently building.
    """
    if not child._building_ended:
        raise ValueError("Child graph has not finished building.")

    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    capture_stream = self._mnff.stream.handle
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(capture_stream))
    # The fields between the graph and the trailing dependency count are
    # collected as a list; how many there are is not fixed here.
    _, _, parent_graph, *dep_fields, dep_count = capture_info

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    current_deps = dep_fields[:dep_count]
    child_node = handle_return(
        driver.cuGraphAddChildGraphNode(
            parent_graph, *current_deps, dep_count, child._mnff.graph
        )
    )

    # The new dependency set is just the child node; any remaining
    # dependency-related slots are filled with None.
    new_dep_fields = [[child_node]]
    new_dep_fields.extend([None] * (len(dep_fields) - 1))

    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            capture_stream,
            *new_dep_fields,  # dependencies, edgeData
            1,
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

714  

def callback(self, fn, *, user_data=None):
    """Insert a host callback into the graph while the stream is capturing.

    When graph execution reaches this point, the callback is run on the
    host CPU. It can be supplied in one of two forms:

    - **Python callable**: any callable taking no arguments. The GIL is
      acquired automatically. Bind state with a closure or
      ``functools.partial``.
    - **ctypes function pointer**: a ``ctypes.CFUNCTYPE`` instance that
      receives one ``void*`` argument (the ``user_data``). The caller is
      responsible for keeping the ctypes wrapper alive for as long as
      the graph lives.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The function to invoke on the host.
    user_data : int or bytes-like, optional
        Used only with ctypes function pointers. An ``int`` is passed
        through as a raw pointer whose lifetime the caller manages. A
        bytes-like object is copied, and the copy's lifetime is tied to
        the graph.

    Raises
    ------
    RuntimeError
        If the stream is not actively capturing, i.e. the graph is not
        being built.
    """
    cdef Stream capture_stream = <Stream>self._mnff.stream
    cdef cydriver.CUstream cu_stream = as_cu(capture_stream._h_stream)
    cdef cydriver.CUstreamCaptureStatus status
    cdef cydriver.CUgraph cu_graph = NULL
    cdef cydriver.CUhostFn host_fn
    cdef void* host_arg = NULL

    # The capture-info call takes one more argument when built against
    # CUDA 13 or newer, hence the compile-time switch.
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                cu_stream, &status, NULL, &cu_graph, NULL, NULL, NULL))
        ELSE:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                cu_stream, &status, NULL, &cu_graph, NULL, NULL))

    if status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    # The helper fills in the C function pointer and its argument for
    # the requested callback.
    _attach_host_callback_to_graph(cu_graph, fn, user_data, &host_fn, &host_arg)

    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(cu_stream, host_fn, host_arg))

765  

766  

class Graph:
    """An executable graph.

    A graph bundles CUDA kernels and other CUDA operations and runs them
    according to a specified dependency tree. Combining the driver
    activities tied to kernel launches and CUDA API calls speeds up the
    workflow.

    A graph is obtained by building it with a :obj:`~graph.GraphBuilder`
    object; it cannot be constructed directly.

    """

    class _MembersNeededForFinalize:
        # Holds the executable graph handle separately from the Graph
        # object so the finalizer can destroy it without referencing Graph.
        __slots__ = "graph"

        def __init__(self, graph_obj, graph):
            self.graph = graph
            # Run close() once graph_obj is garbage collected.
            weakref.finalize(graph_obj, self.close)

        def close(self):
            # Nothing to do when the handle is unset or already destroyed.
            if not self.graph:
                return
            handle_return(driver.cuGraphExecDestroy(self.graph))
            self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        # Bypass __init__, which always raises, and attach the handle holder.
        instance = cls.__new__(cls)
        instance._mnff = Graph._MembersNeededForFinalize(instance, graph)
        return instance

    def close(self):
        """Destroy the graph."""
        self._mnff.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            The returned handle is a Python object. Call ``int()`` on it
            to obtain the memory address of the underlying C handle.

        """
        return self._mnff.graph

    def update(self, source: "GraphBuilder | GraphDefinition") -> None:
        """Update this graph from a new graph definition.

        The topology of ``source`` must be identical to this graph's.

        Parameters
        ----------
        source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
            The graph definition to update from. A GraphBuilder must
            have finished building.

        Raises
        ------
        ValueError
            If ``source`` is a GraphBuilder that has not finished building.
        TypeError
            If ``source`` is neither a GraphBuilder nor a GraphDefinition.
        CUDAError
            If the driver reports a graph exec update failure; the
            message includes the reason given by the driver.

        """
        from cuda.core.graph import GraphDefinition

        cdef cydriver.CUgraph cu_graph
        cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)
        cdef cydriver.CUgraphExecUpdateResultInfo result_info
        cdef cydriver.CUresult err

        # Pick the Python-level graph handle for the supported source types.
        if isinstance(source, GraphBuilder):
            if not source._building_ended:
                raise ValueError("Graph has not finished building.")
            source_graph = source._mnff.graph
        elif isinstance(source, GraphDefinition):
            source_graph = source.handle
        else:
            raise TypeError(
                f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}")
        cu_graph = <cydriver.CUgraph><intptr_t>int(source_graph)

        with nogil:
            err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)

        # An update failure carries a detailed reason in result_info;
        # report it instead of the generic error code.
        if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
            reason = driver.CUgraphExecUpdateResult(result_info.result)
            raise CUDAError(f"Graph update failed: {reason.__doc__.strip()} ({reason.name})")
        HANDLE_RETURN(err)

    def upload(self, stream: Stream):
        """Upload the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph.

        """
        upload_result = driver.cuGraphUpload(self._mnff.graph, stream.handle)
        handle_return(upload_result)

    def launch(self, stream: Stream):
        """Launch the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph.

        """
        launch_result = driver.cuGraphLaunch(self._mnff.graph, stream.handle)
        handle_return(launch_result)