Coverage for cuda/core/graph/_graph_builder.pyx: 88.17%

389 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-13 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5import weakref 

6from dataclasses import dataclass 

7from typing import TYPE_CHECKING 

8  

9from libc.stdint cimport intptr_t 

10  

11from cuda.bindings cimport cydriver 

12  

13from cuda.core.graph._graph_definition cimport GraphCondition 

14from cuda.core.graph._utils cimport _attach_host_callback_to_graph 

15from cuda.core._resource_handles cimport as_cu 

16from cuda.core._stream cimport Stream 

17from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

18from cuda.core._utils.version cimport cy_binding_version, cy_driver_version 

19  

20from cuda.core._utils.cuda_utils import ( 

21 CUDAError, 

22 driver, 

23 handle_return, 

24) 

25  

26if TYPE_CHECKING: 

27 from cuda.core.graph._graph_definition import GraphDefinition 

28  

29__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions'] 

30  

31  

@dataclass
class GraphDebugPrintOptions:
    """Switches selecting what debug_dot_print() writes to the DOT file.

    Every switch is a boolean that defaults to False. An enabled switch adds
    the ``CU_GRAPH_DEBUG_DOT_FLAGS_*`` driver flag whose suffix is the
    upper-cased attribute name.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag is enabled
    runtime_types : bool
        Use CUDA Runtime structures for output
    kernel_node_params : bool
        Adds kernel parameter values to output
    memcpy_node_params : bool
        Adds memcpy parameter values to output
    memset_node_params : bool
        Adds memset parameter values to output
    host_node_params : bool
        Adds host parameter values to output
    event_node_params : bool
        Adds event parameter values to output
    ext_semas_signal_node_params : bool
        Adds external semaphore signal parameter values to output
    ext_semas_wait_node_params : bool
        Adds external semaphore wait parameter values to output
    kernel_node_attributes : bool
        Adds kernel node attributes to output
    handles : bool
        Adds node handles and every kernel function handle to output
    mem_alloc_node_params : bool
        Adds memory alloc parameter values to output
    mem_free_node_params : bool
        Adds memory free parameter values to output
    batch_mem_op_node_params : bool
        Adds batch mem op parameter values to output
    extra_topo_info : bool
        Adds edge numbering information
    conditional_node_params : bool
        Adds conditional node parameter values to output

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Fold the enabled switches into a driver flag bitmask (internal use)."""
        # Listed in declaration order so flags are OR-ed in the same sequence
        # as the attributes are defined above.
        switch_names = (
            "verbose",
            "runtime_types",
            "kernel_node_params",
            "memcpy_node_params",
            "memset_node_params",
            "host_node_params",
            "event_node_params",
            "ext_semas_signal_node_params",
            "ext_semas_wait_node_params",
            "kernel_node_attributes",
            "handles",
            "mem_alloc_node_params",
            "mem_free_node_params",
            "batch_mem_op_node_params",
            "extra_topo_info",
            "conditional_node_params",
        )
        mask = 0
        for switch in switch_names:
            if getattr(self, switch):
                # The driver enum is only touched for switches that are on.
                mask |= getattr(
                    driver.CUgraphDebugDot_flags,
                    f"CU_GRAPH_DEBUG_DOT_FLAGS_{switch.upper()}",
                )
        return mask

126  

127  

@dataclass
class GraphCompleteOptions:
    """Settings applied when a built graph is instantiated.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Free memory allocated in the graph automatically before it is
        relaunched. (Default to False)
    upload_stream : Stream, optional
        When given, the graph is uploaded on this stream right after
        completion. (Default to None)
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms
        that support unified addressing, and not together with
        auto_free_on_launch. (Default to False)
    use_node_priority : bool, optional
        Honor the per-node priority attributes instead of the priority of
        the stream the graph is launched into. (Default to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False

152  

153  

def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate ``h_graph`` and wrap the executable handle in a ``Graph``.

    Parameters
    ----------
    h_graph
        Graph handle passed to ``cuGraphInstantiateWithParams``.
    options : GraphCompleteOptions, optional
        When truthy, its fields are translated into instantiate flags and,
        for ``upload_stream``, into the upload stream handle.

    Returns
    -------
    Graph
        Object produced by ``Graph._init`` from the instantiated handle.

    Raises
    ------
    RuntimeError
        If the ``result_out`` field of the instantiate parameters reports
        anything other than success.
    """
    inst_params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        flag_enum = driver.CUgraphInstantiate_flags
        bits = 0
        if options.auto_free_on_launch:
            bits |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            bits |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            inst_params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            bits |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            bits |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        inst_params.flags = bits

    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, inst_params)))

    # Inspect the detailed result reported through the parameter struct.
    outcome = inst_params.result_out
    result_enum = driver.CUgraphInstantiateResult
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_ERROR:
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    # The enum member below is only looked up on bindings >= 12.8.
    if (
        cy_binding_version() >= (12, 8, 0)
        and outcome == result_enum.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    if outcome != result_enum.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {outcome}")
    return graph

190  

191  

192class GraphBuilder: 

193 """A graph under construction by stream capture. 

194  

195 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

196 them with a specified dependency tree. It speeds up the workflow by combining the 

197 driver activities associated with CUDA kernel launches and CUDA API calls. 

198  

199 Directly creating a :obj:`~graph.GraphBuilder` is not supported due 

200 to ambiguity. New graph builders should instead be created through a 

201 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object. 

202  

203 """ 

204  

class _MembersNeededForFinalize:
    """Resources released when the owning graph builder goes away.

    The cleanup callback registered with ``weakref.finalize`` is bound to
    this helper object rather than to the builder it is registered for.
    """

    __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

    def __init__(self, graph_builder_obj: GraphBuilder, stream_obj: Stream | None, is_stream_owner: bool, conditional_graph, is_join_required: bool) -> None:
        self.graph = None
        self.conditional_graph = conditional_graph
        self.stream = stream_obj
        self.is_stream_owner = is_stream_owner
        self.is_join_required = is_join_required
        # Run close() once graph_builder_obj is garbage collected.
        weakref.finalize(graph_builder_obj, self.close)

    def close(self) -> None:
        """End any pending capture, then release the stream and graph."""
        stream = self.stream
        if stream:
            if not self.is_join_required:
                status = handle_return(driver.cuStreamGetCaptureInfo(stream.handle))[0]
                if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                    # Note how this condition only occurs for the primary graph builder.
                    # This is because calling cuStreamEndCapture on streams that were split
                    # off of the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                    # Therefore, it is currently a requirement that users join all split graph
                    # builders before a graph builder can be cleanly destroyed.
                    handle_return(driver.cuStreamEndCapture(stream.handle))
            if self.is_stream_owner:
                stream.close()
        self.stream = None
        graph = self.graph
        if graph:
            handle_return(driver.cuGraphDestroy(graph))
        self.graph = None
        self.conditional_graph = None

234  

235 __slots__ = ("__weakref__", "_building_ended", "_mnff") 

236  

def __init__(self) -> None:
    """Reject direct construction of a graph builder.

    Raises
    ------
    NotImplementedError
        Always; the message points callers at the ``create_graph_builder()``
        factory methods instead.
    """
    # The message names GraphBuilder (the class being constructed), not Graph.
    raise NotImplementedError(
        "directly creating a GraphBuilder object can be ambiguous. Please either "
        "call Device.create_graph_builder() or stream.create_graph_builder()"
    )

242  

@classmethod
def _init(cls, stream: Stream | None, is_stream_owner: bool, conditional_graph: object = None, is_join_required: bool = False) -> GraphBuilder:
    """Internal factory that bypasses ``__init__`` and wires up builder state.

    The arguments are forwarded unchanged to ``_MembersNeededForFinalize``;
    the new builder starts with ``_building_ended`` set to False.
    """
    builder = cls.__new__(cls)
    builder._building_ended = False
    builder._mnff = GraphBuilder._MembersNeededForFinalize(
        builder,
        stream,
        is_stream_owner,
        conditional_graph,
        is_join_required,
    )
    return builder

252  

@property
def stream(self) -> Stream:
    """The stream object held by this graph builder."""
    finalizer_state = self._mnff
    return finalizer_state.stream

257  

@property
def is_join_required(self) -> bool:
    """True when this graph builder has to be joined before building ends."""
    finalizer_state = self._mnff
    return finalizer_state.is_join_required

262  

def begin_building(self, mode: str | None = "relaxed") -> GraphBuilder:
    """Begins the building process.

    Build `mode` for controlling interaction with other API calls must be one of the following:

    - `global` : Prohibit potentially unsafe operations across all streams in the process.
    - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
    - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

    Parameters
    ----------
    mode : str, optional
        Build mode to control the interaction with other API calls that are potentially unsafe.
        Default set to use relaxed.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        This graph builder, so calls can be chained.

    Raises
    ------
    RuntimeError
        If building has already been ended on this builder.
    ValueError
        If `mode` is not one of the supported build modes.

    """
    if self._building_ended:
        raise RuntimeError("Cannot resume building after building has ended.")

    # A single chain both validates the mode and selects the driver capture
    # mode, so the set of accepted names lives in exactly one place.
    if mode == "global":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL
    elif mode == "thread_local":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
    elif mode == "relaxed":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED
    else:
        raise ValueError(f"Unsupported build mode: {mode}")

    if self._mnff.conditional_graph:
        # Capture into the existing conditional graph.
        handle_return(
            driver.cuStreamBeginCaptureToGraph(
                self._mnff.stream.handle,
                self._mnff.conditional_graph,
                None,  # dependencies
                None,  # dependencyData
                0,  # numDependencies
                capture_mode,
            )
        )
    else:
        handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode))
    return self

306  

@property
def is_building(self) -> bool:
    """Whether the builder's stream is currently capturing.

    Raises
    ------
    RuntimeError
        If the capture status reported by the driver is "invalidated".
    NotImplementedError
        If the driver reports a capture status this code does not know.
    """
    status_enum = driver.CUstreamCaptureStatus
    status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
    if status == status_enum.CU_STREAM_CAPTURE_STATUS_NONE:
        return False
    if status == status_enum.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return True
    if status == status_enum.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. Build process must now be ended."
        )
    raise NotImplementedError(f"Unsupported capture status type received: {status}")

321  

def end_building(self) -> GraphBuilder:
    """Stop capturing and keep the resulting graph handle on the builder.

    Raises
    ------
    RuntimeError
        If the builder is not currently building.
    """
    if not self.is_building:
        raise RuntimeError("Graph builder is not building.")

    # The handle returned by the driver goes back into whichever slot this
    # builder is capturing for.
    slot = "conditional_graph" if self._mnff.conditional_graph else "graph"
    captured = handle_return(driver.cuStreamEndCapture(self.stream.handle))
    setattr(self._mnff, slot, captured)

    # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
    # resume the build process after the first call to end_building()
    self._building_ended = True
    return self

335  

def complete(self, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate the finished build into a :obj:`~graph.Graph`.

    Parameters
    ----------
    options : :obj:`~graph.GraphCompleteOptions`, optional
        Dataclass carrying the instantiation options.

    Returns
    -------
    graph : :obj:`~graph.Graph`
        The newly built graph.

    Raises
    ------
    RuntimeError
        If ``end_building()`` has not completed on this builder yet.

    """
    if self._building_ended:
        built_graph = self._mnff.graph
        return _instantiate_graph(built_graph, options)
    raise RuntimeError("Graph has not finished building.")

354  

def debug_dot_print(self, path: str, options: GraphDebugPrintOptions | None = None) -> None:
    """Write a DOT debug file describing the built graph.

    Parameters
    ----------
    path : str
        File path to use for writing debug DOT output
    options : :obj:`~graph.GraphDebugPrintOptions`, optional
        Dataclass carrying the debug print options.

    Raises
    ------
    RuntimeError
        If building has not been ended on this builder.

    """
    if not self._building_ended:
        raise RuntimeError("Graph has not finished building.")
    if options:
        dot_flags = options._to_flags()
    else:
        dot_flags = 0
    # Keep the encoded bytes object alive while the C pointer is in use.
    cdef bytes encoded_path = path.encode('utf-8')
    cdef const char* raw_path = encoded_path
    handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, raw_path, dot_flags))

372  

def split(self, count: int) -> tuple[GraphBuilder, ...]:
    """Splits the original graph builder into multiple graph builders.

    The new builders inherit work dependencies from the original builder.
    The original builder is reused for the split and is returned first in the tuple.

    Parameters
    ----------
    count : int
        The number of graph builders to split the graph builder into.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of split graph builders. The first graph builder in the tuple
        is always the original graph builder.

    Raises
    ------
    ValueError
        If `count` is smaller than 2.

    """
    if count < 2:
        raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

    # Every new stream waits on one event recorded on the original stream.
    event = self._mnff.stream.record()
    builders = [self]
    try:
        for _ in range(count - 1):
            stream = self._mnff.stream.device.create_stream()
            stream.wait(event)
            builders.append(
                GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True)
            )
    finally:
        # Close the event even when creating or wiring a stream fails, so it
        # is not leaked on the error path.
        event.close()
    return tuple(builders)

404  

@staticmethod
def join(*graph_builders: GraphBuilder) -> GraphBuilder:
    """Merge several graph builders back into one.

    The returned builder inherits work dependencies from the provided builders.

    Parameters
    ----------
    *graph_builders : :obj:`~graph.GraphBuilder`
        The graph builders to join.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The builder the others were joined onto: the first one that does not
        require a join, or the first argument if all of them do.

    Raises
    ------
    TypeError
        If any argument is not a GraphBuilder.
    ValueError
        If fewer than two builders are given.

    """
    if not all(isinstance(candidate, GraphBuilder) for candidate in graph_builders):
        raise TypeError("All arguments must be GraphBuilder instances")
    if len(graph_builders) < 2:
        raise ValueError("Must join with at least two graph builders")

    # Pick the root: first builder not flagged as needing a join, else index 0.
    root_idx = next(
        (pos for pos, candidate in enumerate(graph_builders) if not candidate.is_join_required),
        0,
    )
    root_bdr = graph_builders[root_idx]

    # Make the root stream wait on every other builder, then close that builder.
    for pos, candidate in enumerate(graph_builders):
        if pos != root_idx:
            root_bdr.stream.wait(candidate.stream)
            candidate.close()

    return root_bdr

443  

def __cuda_stream__(self) -> tuple[int, int]:
    """Delegate the ``__cuda_stream__`` protocol to the builder's stream."""
    underlying = self.stream
    return underlying.__cuda_stream__()

447  

def _get_conditional_context(self) -> driver.CUcontext:
    """Return the handle of the context belonging to the builder's stream."""
    stream_context = self._mnff.stream.context
    return stream_context.handle

450  

def create_condition(self, default_value: int | None = None) -> GraphCondition:
    """Create a condition variable for use with conditional nodes.

    The returned :class:`GraphCondition` object is handed to the
    conditional-node builder methods (:meth:`if_then`, :meth:`if_else`,
    :meth:`while_loop`, :meth:`switch`). Its value is controlled at runtime
    by device code via ``cudaGraphSetConditional``.

    Parameters
    ----------
    default_value : int, optional
        The default value to assign to the condition. If None, no
        default is assigned.

    Returns
    -------
    GraphCondition
        A condition variable for controlling conditional execution.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3, or if the
        builder's stream is not actively capturing.
    """
    if cy_driver_version() < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional handles")
    if cy_binding_version() < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional handles")

    if default_value is None:
        # No default requested: pass 0 with no assign-default flag.
        default_value = 0
        flags = 0
    else:
        flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT

    capture_info = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
    # Only the status (first field) and the graph (third field) are needed.
    capture_status, _, capturing_graph, *_, _ = capture_info
    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot create a condition when graph is not being built")

    conditional_ctx = self._get_conditional_context()
    raw_handle = handle_return(
        driver.cuGraphConditionalHandleCreate(capturing_graph, conditional_ctx, default_value, flags)
    )
    return GraphCondition._from_handle(<cydriver.CUgraphConditionalHandle><intptr_t>int(raw_handle))

488  

def _cond_with_params(self, node_params: object) -> tuple[GraphBuilder, ...]:
    """Add a conditional node to the capturing graph and return its branch builders.

    Parameters
    ----------
    node_params : object
        Node parameters handed to ``cuGraphAddNode``. After the node is
        added, ``node_params.conditional.size`` and
        ``node_params.conditional.phGraph_out`` are read to create one
        builder per branch graph.

    Returns
    -------
    tuple[GraphBuilder, ...]
        One new builder per entry of ``conditional.phGraph_out``. Each owns
        a freshly created stream and is created with
        ``is_join_required=False``.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    # Get current capture info to ensure we're in a valid state
    # NOTE(review): `deps_info` collects every field between the graph and the
    # trailing count; how many fields that is presumably depends on the
    # bindings version -- confirm against cuStreamGetCaptureInfo.
    status, _, graph, *deps_info, num_dependencies = handle_return(
        driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
    )
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    # Add the conditional node to the graph
    # The new node becomes the single dependency; any remaining positional
    # slots that `deps_info` occupied are filled with None.
    deps_info_update = [
        [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))]
    ] + [None] * (len(deps_info) - 1)

    # Update the stream's capture dependencies
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            self._mnff.stream.handle,
            *deps_info_update,  # dependencies, edgeData
            1,  # numDependencies
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

    # Create new graph builders for each condition
    return tuple(
        [
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[i],
                is_join_required=False,
            )
            for i in range(node_params.conditional.size)
        ]
    )

524  

def if_then(self, condition: GraphCondition) -> GraphBuilder:
    """Add an if-conditional node and return the builder for its body.

    The body only executes when the condition evaluates to true at runtime.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        whether the branch executes.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created conditional graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3.
    TypeError
        If `condition` is not a GraphCondition.

    """
    if cy_driver_version() < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if")
    if cy_binding_version() < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if")
    if not isinstance(condition, GraphCondition):
        type_error_msg = (
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )
        raise TypeError(type_error_msg)

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = condition.handle
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    # A single body graph: the "then" branch.
    params.conditional.size = 1
    params.conditional.ctx = self._get_conditional_context()

    branch_builders = self._cond_with_params(params)
    return branch_builders[0]

560  

def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]:
    """Add an if-else conditional node and return builders for both bodies.

    The first body executes when the condition evaluates to true at
    runtime; otherwise the else body executes.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        which branch executes.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        A tuple of two new graph builders, one for the if branch and one for the else branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.
    TypeError
        If `condition` is not a GraphCondition.

    """
    if cy_driver_version() < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if-else")
    if cy_binding_version() < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if-else")
    if not isinstance(condition, GraphCondition):
        type_error_msg = (
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )
        raise TypeError(type_error_msg)

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = condition.handle
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    # Two body graphs: the "if" branch and the "else" branch.
    params.conditional.size = 2
    params.conditional.ctx = self._get_conditional_context()

    branch_builders = self._cond_with_params(params)
    return branch_builders

596  

def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]:
    """Add a switch conditional node and return a builder for every case.

    At runtime the branch whose case index matches the value of the
    condition is executed. If no case matches, no branch is executed.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` selecting
        which case executes.
    count : int
        The number of cases to add to the switch conditional. The value
        is forwarded unchecked as the conditional node's ``size``.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of new graph builders, one for each branch.

    Raises
    ------
    RuntimeError
        If the driver or the bindings are older than 12.8.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    if cy_driver_version() < (12, 8, 0):
        found = ".".join(str(part) for part in cy_driver_version())
        raise RuntimeError(f"Driver version {found} does not support conditional switch")
    if cy_binding_version() < (12, 8, 0):
        found = ".".join(str(part) for part in cy_binding_version())
        raise RuntimeError(f"Binding version {found} does not support conditional switch")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            "condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    # One body graph is requested per case.
    params.conditional.size = count
    params.conditional.handle = condition.handle
    params.conditional.ctx = self._get_conditional_context()
    return self._cond_with_params(params)

635  

def while_loop(self, condition: GraphCondition) -> GraphBuilder:
    """Add a while-loop conditional node and return a builder for its body.

    At runtime the loop body executes repeatedly until the condition
    evaluates to false.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        loop continuation.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created while loop graph builder.

    Raises
    ------
    RuntimeError
        If the driver or the bindings are older than 12.3.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    if cy_driver_version() < (12, 3, 0):
        found = ".".join(str(part) for part in cy_driver_version())
        raise RuntimeError(f"Driver version {found} does not support conditional while loop")
    if cy_binding_version() < (12, 3, 0):
        found = ".".join(str(part) for part in cy_binding_version())
        raise RuntimeError(f"Binding version {found} does not support conditional while loop")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            "condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    params.conditional.size = 1
    params.conditional.handle = condition.handle
    params.conditional.ctx = self._get_conditional_context()

    # A single body graph was requested, so only the first builder is returned.
    builders = self._cond_with_params(params)
    return builders[0]

671  

def close(self) -> None:
    """Destroy the graph builder.

    Closes the associated stream if this builder owns it. A borrowed
    stream object instead has its reference released.

    """
    # All cleanup is delegated to the finalizer-state holder.
    finalizer_state = self._mnff
    finalizer_state.close()

680  

def embed(self, child: GraphBuilder) -> None:
    """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node.

    The child graph is added as a child-graph node of the graph currently
    being captured on this builder's stream, and the stream's capture
    dependencies are then replaced by that new node.

    Parameters
    ----------
    child : :obj:`~graph.GraphBuilder`
        The child graph builder. Must have finished building.

    Raises
    ------
    ValueError
        If ``child`` has not finished building, or if this builder is
        not currently building.
    """
    if not child._building_ended:
        raise ValueError("Child graph has not finished building.")

    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    capture_stream = self._mnff.stream.handle
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(capture_stream))
    # The fields between the graph and the trailing dependency count are
    # collected as a list, whatever their number.
    _, _, parent_graph, *dep_fields, dep_count = capture_info

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    # NOTE(review): this slices the list of returned fields (not the
    # dependency list itself) by the dependency count -- confirm against
    # the linked discussion before changing.
    forwarded_fields = dep_fields[:dep_count]
    child_node = handle_return(
        driver.cuGraphAddChildGraphNode(
            parent_graph, *forwarded_fields, dep_count, child._mnff.graph
        )
    )

    # The new child node becomes the sole dependency; any remaining
    # dependency-info slots are passed as None.
    replacement_fields = [[child_node]]
    replacement_fields.extend(None for _ in range(len(dep_fields) - 1))

    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            capture_stream,
            *replacement_fields,  # dependencies, edgeData
            1,
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

720  

def callback(self, fn, *, user_data=None) -> None:
    """Add a host callback to the graph during stream capture.

    The callback runs on the host CPU when the graph reaches this point
    in execution. Two modes are supported:

    - **Python callable**: Pass any callable. The GIL is acquired
      automatically. The callable must take no arguments; use closures
      or ``functools.partial`` to bind state.
    - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance.
      The function receives a single ``void*`` argument (the
      ``user_data``). The caller must keep the ctypes wrapper alive
      for the lifetime of the graph.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The callback function.
    user_data : int or bytes-like, optional
        Only for ctypes function pointers. If ``int``, passed as a raw
        pointer (caller manages lifetime). If bytes-like, the data is
        copied and its lifetime is tied to the graph.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    cdef Stream capture_stream = <Stream>self._mnff.stream
    cdef cydriver.CUstream raw_stream = as_cu(capture_stream._h_stream)
    cdef cydriver.CUstreamCaptureStatus status
    cdef cydriver.CUgraph captured_graph = NULL
    cdef cydriver.CUhostFn host_fn
    cdef void* host_arg = NULL

    # Query the capture status and the graph under construction. The
    # driver call takes one more argument when built against CUDA 13+.
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                raw_stream, &status, NULL, &captured_graph, NULL, NULL, NULL))
        ELSE:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                raw_stream, &status, NULL, &captured_graph, NULL, NULL))

    if status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    # Resolve ``fn``/``user_data`` into the C function pointer and argument
    # to launch; the helper receives the graph being captured.
    _attach_host_callback_to_graph(captured_graph, fn, user_data, &host_fn, &host_arg)

    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(raw_stream, host_fn, host_arg))

771  

772  

class Graph:
    """An executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and
    executes them with a specified dependency tree. It speeds up the workflow by
    combining the driver activities associated with CUDA kernel launches and CUDA
    API calls.

    Graphs must be built using a :obj:`~graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        # Holds the ``CUgraphExec`` handle apart from the owning Graph so the
        # finalizer callback does not need a reference to the Graph itself.
        __slots__ = "graph"

        def __init__(self, graph_obj: Graph, graph: driver.CUgraphExec) -> None:
            self.graph = graph
            # Run ``close`` once the owning Graph object is finalized.
            weakref.finalize(graph_obj, self.close)

        def close(self) -> None:
            # Nothing to do when the handle is unset or already destroyed.
            if not self.graph:
                return
            handle_return(driver.cuGraphExecDestroy(self.graph))
            self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self) -> None:
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph: driver.CUgraphExec) -> Graph:
        # Bypass ``__init__``, which always raises.
        instance = cls.__new__(cls)
        instance._mnff = Graph._MembersNeededForFinalize(instance, graph)
        return instance

    def close(self) -> None:
        """Destroy the graph."""
        finalizer_state = self._mnff
        finalizer_state.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        finalizer_state = self._mnff
        return finalizer_state.graph

    def update(self, source: "GraphBuilder | GraphDefinition") -> None:
        """Update the graph using a new graph definition.

        The topology of the provided source must be identical to this graph.

        Parameters
        ----------
        source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
            The graph definition to update from. A GraphBuilder must have
            finished building.

        Raises
        ------
        ValueError
            If ``source`` is a GraphBuilder that has not finished building.
        TypeError
            If ``source`` is neither a GraphBuilder nor a GraphDefinition.
        CUDAError
            If the driver reports a graph-exec update failure; the message
            includes the reported reason.

        """
        from cuda.core.graph import GraphDefinition

        cdef cydriver.CUgraph definition_handle
        cdef cydriver.CUgraphExec exec_handle = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)

        # Pick the Python-level handle object first, then cast it once.
        if isinstance(source, GraphBuilder):
            if not source._building_ended:
                raise ValueError("Graph has not finished building.")
            source_handle = source._mnff.graph
        elif isinstance(source, GraphDefinition):
            source_handle = source.handle
        else:
            raise TypeError(
                f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}"
            )
        definition_handle = <cydriver.CUgraph><intptr_t>int(source_handle)

        cdef cydriver.CUgraphExecUpdateResultInfo update_info
        cdef cydriver.CUresult status
        with nogil:
            status = cydriver.cuGraphExecUpdate(exec_handle, definition_handle, &update_info)

        # An update failure gets a dedicated message built from the reported
        # reason; every other status goes through the generic check.
        if status == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
            reason = driver.CUgraphExecUpdateResult(update_info.result)
            raise CUDAError(f"Graph update failed: {reason.__doc__.strip()} ({reason.name})")
        HANDLE_RETURN(status)

    def upload(self, stream: Stream) -> None:
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph.

        """
        upload_result = driver.cuGraphUpload(self._mnff.graph, stream.handle)
        handle_return(upload_result)

    def launch(self, stream: Stream) -> None:
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph.

        """
        launch_result = driver.cuGraphLaunch(self._mnff.graph, stream.handle)
        handle_return(launch_result)