Coverage for cuda / core / _graph.py: 87.43%

334 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-08 01:07 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7import weakref 

8from dataclasses import dataclass 

9from typing import TYPE_CHECKING 

10 

11if TYPE_CHECKING: 

12 from cuda.core._stream import Stream 

13 

14from cuda.core._utils.cuda_utils import ( 

15 driver, 

16 get_binding_version, 

17 handle_return, 

18) 

19 

_inited = False
_driver_ver = None


def _lazy_init():
    """Cache the binding version and the driver version on first use.

    The results are stored in the module-level ``_py_major_minor`` and
    ``_driver_ver`` names. Once ``_inited`` has been set, further calls do
    nothing.
    """
    global _inited, _py_major_minor, _driver_ver
    if not _inited:
        # Which bindings are available depends on the cuda-python version.
        _py_major_minor = get_binding_version()
        _driver_ver = handle_return(driver.cuDriverGetVersion())
        _inited = True

34 

35 

@dataclass
class GraphDebugPrintOptions:
    """Switches controlling what :obj:`_graph.GraphBuilder.debug_dot_print()` emits.

    Every switch is off unless explicitly enabled.

    Attributes
    ----------
    verbose : bool
        Emit all debug data, as though every other switch were enabled.
        Defaults to ``False``.
    runtime_types : bool
        Emit CUDA Runtime structures. Defaults to ``False``.
    kernel_node_params : bool
        Include kernel parameter values. Defaults to ``False``.
    memcpy_node_params : bool
        Include memcpy parameter values. Defaults to ``False``.
    memset_node_params : bool
        Include memset parameter values. Defaults to ``False``.
    host_node_params : bool
        Include host parameter values. Defaults to ``False``.
    event_node_params : bool
        Include event parameter values. Defaults to ``False``.
    ext_semas_signal_node_params : bool
        Include external semaphore signal parameter values. Defaults to ``False``.
    ext_semas_wait_node_params : bool
        Include external semaphore wait parameter values. Defaults to ``False``.
    kernel_node_attributes : bool
        Include kernel node attributes. Defaults to ``False``.
    handles : bool
        Include node handles and every kernel function handle. Defaults to ``False``.
    mem_alloc_node_params : bool
        Include memory alloc parameter values. Defaults to ``False``.
    mem_free_node_params : bool
        Include memory free parameter values. Defaults to ``False``.
    batch_mem_op_node_params : bool
        Include batch mem op parameter values. Defaults to ``False``.
    extra_topo_info : bool
        Include edge numbering information. Defaults to ``False``.
    conditional_node_params : bool
        Include conditional node parameter values. Defaults to ``False``.

    """

    # Field order is part of the positional constructor signature; keep it stable.
    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

93 

94 

@dataclass
class GraphCompleteOptions:
    """Settings consumed by :obj:`_graph.GraphBuilder.complete()`.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Free memory allocated in the graph automatically before it is
        relaunched. Defaults to ``False``.
    upload_stream : Stream, optional
        Stream on which the graph is automatically uploaded once it has been
        completed. Defaults to ``None``.
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms
        that support unified addressing, and not together with
        ``auto_free_on_launch``. Defaults to ``False``.
    use_node_priority : bool, optional
        Run the graph with the per-node priority attributes instead of the
        priority of the stream it is launched into. Defaults to ``False``.

    """

    # Field order is part of the positional constructor signature; keep it stable.
    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False

119 

120 

121class GraphBuilder: 

122 """Represents a graph under construction. 

123 

124 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

125 them with a specified dependency tree. It speeds up the workflow by combining the 

126 driver activities associated with CUDA kernel launches and CUDA API calls. 

127 

128 Directly creating a :obj:`~_graph.GraphBuilder` is not supported due 

129 to ambiguity. New graph builders should instead be created through a 

130 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object. 

131 

132 """ 

133 

134 class _MembersNeededForFinalize: 

135 __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream") 

136 

137 def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required): 

138 self.stream = stream_obj 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

139 self.is_stream_owner = is_stream_owner 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

140 self.graph = None 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

141 self.conditional_graph = conditional_graph 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

142 self.is_join_required = is_join_required 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

143 weakref.finalize(graph_builder_obj, self.close) 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

144 

145 def close(self): 

146 if self.stream: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

147 if not self.is_join_required: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

148 capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0] 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

149 if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

150 # Note how this condition only occures for the primary graph builder 

151 # This is because calling cuStreamEndCapture streams that were split off of the primary 

152 # would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED. 

153 # Therefore, it is currently a requirement that users join all split graph builders 

154 # before a graph builder can be clearly destroyed. 

155 handle_return(driver.cuStreamEndCapture(self.stream.handle)) 

156 if self.is_stream_owner: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

157 self.stream.close() 1DIwOv5xJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

158 self.stream = None 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

159 if self.graph: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

160 handle_return(driver.cuGraphDestroy(self.graph)) 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

161 self.graph = None 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

162 self.conditional_graph = None 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

163 

164 __slots__ = ("__weakref__", "_building_ended", "_mnff") 

165 

166 def __init__(self): 

167 raise NotImplementedError( 

168 "directly creating a Graph object can be ambiguous. Please either " 

169 "call Device.create_graph_builder() or stream.create_graph_builder()" 

170 ) 

171 

172 @classmethod 

173 def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False): 

174 self = cls.__new__(cls) 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

175 _lazy_init() 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

176 self._mnff = GraphBuilder._MembersNeededForFinalize( 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

177 self, stream, is_stream_owner, conditional_graph, is_join_required 

178 ) 

179 

180 self._building_ended = False 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

181 return self 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

182 

183 @property 

184 def stream(self) -> Stream: 

185 """Returns the stream associated with the graph builder.""" 

186 return self._mnff.stream 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

187 

188 @property 

189 def is_join_required(self) -> bool: 

190 """Returns True if this graph builder must be joined before building is ended.""" 

191 return self._mnff.is_join_required 1vxbcdefghijklmnopqrstua

192 

193 def begin_building(self, mode="relaxed") -> GraphBuilder: 

194 """Begins the building process. 

195 

196 Build `mode` for controlling interaction with other API calls must be one of the following: 

197 

198 - `global` : Prohibit potentially unsafe operations across all streams in the process. 

199 - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread. 

200 - `relaxed` : The local thread is not prohibited from potentially unsafe operations. 

201 

202 Parameters 

203 ---------- 

204 mode : str, optional 

205 Build mode to control the interaction with other API calls that are porentially unsafe. 

206 Default set to use relaxed. 

207 

208 """ 

209 if self._building_ended: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

210 raise RuntimeError("Cannot resume building after building has ended.") 1P

211 if mode not in ("global", "thread_local", "relaxed"): 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

212 raise ValueError(f"Unsupported build mode: {mode}") 14

213 if mode == "global": 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

214 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL 123WXHYZ4

215 elif mode == "thread_local": 1DIwOv5xPJU0V1KQLRFGMSNTbcdefghijklmnopqrstuyzAB4Ea

216 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL 101QRGST4

217 elif mode == "relaxed": 1DIwOv5xPJUVKLFMNbcdefghijklmnopqrstuyzAB4Ea

218 capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED 1DIwOv5xPJUVKLFMNbcdefghijklmnopqrstuyzAB4Ea

219 else: 

220 raise ValueError(f"Unsupported build mode: {mode}") 

221 

222 if self._mnff.conditional_graph: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

223 handle_return( 1wbcdefghijklmnopqrstuyzABa

224 driver.cuStreamBeginCaptureToGraph( 

225 self._mnff.stream.handle, 

226 self._mnff.conditional_graph, 

227 None, # dependencies 

228 None, # dependencyData 

229 0, # numDependencies 

230 capture_mode, 

231 ) 

232 ) 

233 else: 

234 handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

235 return self 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

236 

237 @property 

238 def is_building(self) -> bool: 

239 """Returns True if the graph builder is currently building.""" 

240 capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0] 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

241 if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

242 return False 15

243 elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

244 return True 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

245 elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED: 

246 raise RuntimeError( 

247 "Build process encountered an error and has been invalidated. Build process must now be ended." 

248 ) 

249 else: 

250 raise NotImplementedError(f"Unsupported capture status type received: {capture_status}") 

251 

252 def end_building(self) -> GraphBuilder: 

253 """Ends the building process.""" 

254 if not self.is_building: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

255 raise RuntimeError("Graph builder is not building.") 

256 if self._mnff.conditional_graph: 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

257 self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) 1wbcdefghijklmnopqrstuyzABa

258 else: 

259 self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

260 

261 # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to 

262 # resume the build process after the first call to end_building() 

263 self._building_ended = True 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

264 return self 1DIwOv5xPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzAB4Ea

265 

266 def complete(self, options: GraphCompleteOptions | None = None) -> Graph: 

267 """Completes the graph builder and returns the built :obj:`~_graph.Graph` object. 

268 

269 Parameters 

270 ---------- 

271 options : :obj:`~_graph.GraphCompleteOptions`, optional 

272 Customizable dataclass for the graph builder completion options. 

273 

274 Returns 

275 ------- 

276 graph : :obj:`~_graph.Graph` 

277 The newly built graph. 

278 

279 """ 

280 if not self._building_ended: 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

281 raise RuntimeError("Graph has not finished building.") 1O

282 

283 if (_driver_ver < 12000) or (_py_major_minor < (12, 0)): 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

284 flags = 0 

285 if options: 

286 if options.auto_free_on_launch: 

287 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH 

288 if options.use_node_priority: 

289 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY 

290 return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags))) 

291 

292 params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS() 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

293 if options: 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

294 flags = 0 1HFGE

295 if options.auto_free_on_launch: 1HFGE

296 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH 1HFGE

297 if options.upload_stream: 1HFGE

298 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD 1E

299 params.hUploadStream = options.upload_stream.handle 1E

300 if options.device_launch: 1HFGE

301 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH 1E

302 if options.use_node_priority: 1HFGE

303 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY 1E

304 params.flags = flags 1HFGE

305 

306 graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params))) 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

307 if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR: 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

308 # NOTE: Should never get here since the handle_return should have caught this case 

309 raise RuntimeError( 

310 "Instantiation failed for an unexpected reason which is described in the return value of the function." 

311 ) 

312 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE: 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

313 raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.") 

314 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED: 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

315 raise RuntimeError( 

316 "Instantiation for device launch failed because the graph contained an unsupported operation." 

317 ) 

318 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED: 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

319 raise RuntimeError( 

320 "Instantiation for device launch failed due to the nodes belonging to different contexts." 

321 ) 

322 elif ( 

323 _py_major_minor >= (12, 8) 

324 and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED 

325 ): 

326 raise RuntimeError("One or more conditional handles are not associated with conditional builders.") 

327 elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS: 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

328 raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}") 

329 return graph 1DIwOvxPJ2U03V1WKQXLRHFGYMSZNTbcdefghijklmnopqrstuyzABE

330 

331 def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None): 

332 """Generates a DOT debug file for the graph builder. 

333 

334 Parameters 

335 ---------- 

336 path : str 

337 File path to use for writting debug DOT output 

338 options : :obj:`~_graph.GraphDebugPrintOptions`, optional 

339 Customizable dataclass for the debug print options. 

340 

341 """ 

342 if not self._building_ended: 1a

343 raise RuntimeError("Graph has not finished building.") 

344 flags = 0 1a

345 if options: 1a

346 if options.verbose: 1a

347 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE 1a

348 if options.runtime_types: 1a

349 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES 1a

350 if options.kernel_node_params: 1a

351 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS 1a

352 if options.memcpy_node_params: 1a

353 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS 1a

354 if options.memset_node_params: 1a

355 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS 1a

356 if options.host_node_params: 1a

357 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS 1a

358 if options.event_node_params: 1a

359 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS 1a

360 if options.ext_semas_signal_node_params: 1a

361 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS 1a

362 if options.ext_semas_wait_node_params: 1a

363 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS 1a

364 if options.kernel_node_attributes: 1a

365 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES 1a

366 if options.handles: 1a

367 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES 1a

368 if options.mem_alloc_node_params: 1a

369 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS 1a

370 if options.mem_free_node_params: 1a

371 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS 1a

372 if options.batch_mem_op_node_params: 1a

373 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS 1a

374 if options.extra_topo_info: 1a

375 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO 1a

376 if options.conditional_node_params: 1a

377 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS 1a

378 

379 handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags)) 1a

380 

381 def split(self, count: int) -> tuple[GraphBuilder, ...]: 

382 """Splits the original graph builder into multiple graph builders. 

383 

384 The new builders inherit work dependencies from the original builder. 

385 The original builder is reused for the split and is returned first in the tuple. 

386 

387 Parameters 

388 ---------- 

389 count : int 

390 The number of graph builders to split the graph builder into. 

391 

392 Returns 

393 ------- 

394 graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...] 

395 A tuple of split graph builders. The first graph builder in the tuple 

396 is always the original graph builder. 

397 

398 """ 

399 if count < 2: 1vxbcdefghijklmnopqrstua

400 raise ValueError(f"Invalid split count: expecting >= 2, got {count}") 1v

401 

402 event = self._mnff.stream.record() 1vxbcdefghijklmnopqrstua

403 result = [self] 1vxbcdefghijklmnopqrstua

404 for i in range(count - 1): 1vxbcdefghijklmnopqrstua

405 stream = self._mnff.stream.device.create_stream() 1vxbcdefghijklmnopqrstua

406 stream.wait(event) 1vxbcdefghijklmnopqrstua

407 result.append( 1vxbcdefghijklmnopqrstua

408 GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) 

409 ) 

410 event.close() 1vxbcdefghijklmnopqrstua

411 return result 1vxbcdefghijklmnopqrstua

412 

413 @staticmethod 

414 def join(*graph_builders) -> GraphBuilder: 

415 """Joins multiple graph builders into a single graph builder. 

416 

417 The returned builder inherits work dependencies from the provided builders. 

418 

419 Parameters 

420 ---------- 

421 *graph_builders : :obj:`~_graph.GraphBuilder` 

422 The graph builders to join. 

423 

424 Returns 

425 ------- 

426 graph_builder : :obj:`~_graph.GraphBuilder` 

427 The newly joined graph builder. 

428 

429 """ 

430 if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): 1vxbcdefghijklmnopqrstua

431 raise TypeError("All arguments must be GraphBuilder instances") 

432 if len(graph_builders) < 2: 1vxbcdefghijklmnopqrstua

433 raise ValueError("Must join with at least two graph builders") 1v

434 

435 # Discover the root builder others should join 

436 root_idx = 0 1vxbcdefghijklmnopqrstua

437 for i, builder in enumerate(graph_builders): 1vxbcdefghijklmnopqrstua

438 if not builder.is_join_required: 1vxbcdefghijklmnopqrstua

439 root_idx = i 1vxbcdefghijklmnopqrstua

440 break 1vxbcdefghijklmnopqrstua

441 

442 # Join all onto the root builder 

443 root_bdr = graph_builders[root_idx] 1vxbcdefghijklmnopqrstua

444 for idx, builder in enumerate(graph_builders): 1vxbcdefghijklmnopqrstua

445 if idx == root_idx: 1vxbcdefghijklmnopqrstua

446 continue 1vxbcdefghijklmnopqrstua

447 root_bdr.stream.wait(builder.stream) 1vxbcdefghijklmnopqrstua

448 builder.close() 1vxbcdefghijklmnopqrstua

449 

450 return root_bdr 1vxbcdefghijklmnopqrstua

451 

452 def __cuda_stream__(self) -> tuple[int, int]: 

453 """Return an instance of a __cuda_stream__ protocol.""" 

454 return self.stream.__cuda_stream__() 

455 

456 def _get_conditional_context(self) -> driver.CUcontext: 

457 return self._mnff.stream.context.handle 1wbcdefghijklmnopqrstuyzABa

458 

459 def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle: 

460 """Creates a conditional handle for the graph builder. 

461 

462 Parameters 

463 ---------- 

464 default_value : int, optional 

465 The default value to assign to the conditional handle. 

466 

467 Returns 

468 ------- 

469 handle : driver.CUgraphConditionalHandle 

470 The newly created conditional handle. 

471 

472 """ 

473 if _driver_ver < 12030: 1wbcdefghijklmnopqrstuyzABa

474 raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles") 

475 if _py_major_minor < (12, 3): 1wbcdefghijklmnopqrstuyzABa

476 raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles") 

477 if default_value is not None: 1wbcdefghijklmnopqrstuyzABa

478 flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT 1wyzAB

479 else: 

480 default_value = 0 1bcdefghijklmnopqrstua

481 flags = 0 1bcdefghijklmnopqrstua

482 

483 status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)) 1wbcdefghijklmnopqrstuyzABa

484 if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1wbcdefghijklmnopqrstuyzABa

485 raise RuntimeError("Cannot create a conditional handle when graph is not being built") 

486 

487 return handle_return( 1wbcdefghijklmnopqrstuyzABa

488 driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags) 

489 ) 

490 

491 def _cond_with_params(self, node_params) -> GraphBuilder: 

492 # Get current capture info to ensure we're in a valid state 

493 status, _, graph, *deps_info, num_dependencies = handle_return( 1wbcdefghijklmnopqrstuyzABa

494 driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) 

495 ) 

496 if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1wbcdefghijklmnopqrstuyzABa

497 raise RuntimeError("Cannot add conditional node when not actively capturing") 

498 

499 # Add the conditional node to the graph 

500 deps_info_update = [ 1wbcdefghijklmnopqrstuyzABa

501 [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] 

502 ] + [None] * (len(deps_info) - 1) 

503 

504 # Update the stream's capture dependencies 

505 handle_return( 1wbcdefghijklmnopqrstuyzABa

506 driver.cuStreamUpdateCaptureDependencies( 

507 self._mnff.stream.handle, 

508 *deps_info_update, # dependencies, edgeData 

509 1, # numDependencies 

510 driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, 

511 ) 

512 ) 

513 

514 # Create new graph builders for each condition 

515 return tuple( 1wbcdefghijklmnopqrstuyzABa

516 [ 

517 GraphBuilder._init( 

518 stream=self._mnff.stream.device.create_stream(), 

519 is_stream_owner=True, 

520 conditional_graph=node_params.conditional.phGraph_out[i], 

521 is_join_required=False, 

522 ) 

523 for i in range(node_params.conditional.size) 

524 ] 

525 ) 

526 

527 def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder: 

528 """Adds an if condition branch and returns a new graph builder for it. 

529 

530 The resulting if graph will only execute the branch if the conditional 

531 handle evaluates to true at runtime. 

532 

533 The new builder inherits work dependencies from the original builder. 

534 

535 Parameters 

536 ---------- 

537 handle : driver.CUgraphConditionalHandle 

538 The handle to use for the if conditional. 

539 

540 Returns 

541 ------- 

542 graph_builder : :obj:`~_graph.GraphBuilder` 

543 The newly created conditional graph builder. 

544 

545 """ 

546 if _driver_ver < 12030: 1bcdefghia

547 raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if") 

548 if _py_major_minor < (12, 3): 1bcdefghia

549 raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if") 

550 node_params = driver.CUgraphNodeParams() 1bcdefghia

551 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL 1bcdefghia

552 node_params.conditional.handle = handle 1bcdefghia

553 node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF 1bcdefghia

554 node_params.conditional.size = 1 1bcdefghia

555 node_params.conditional.ctx = self._get_conditional_context() 1bcdefghia

556 return self._cond_with_params(node_params)[0] 1bcdefghia

557 

558 def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]: 

559 """Adds an if-else condition branch and returns new graph builders for both branches. 

560 

561 The resulting if graph will execute the branch if the conditional handle 

562 evaluates to true at runtime, otherwise the else branch will execute. 

563 

564 The new builders inherit work dependencies from the original builder. 

565 

566 Parameters 

567 ---------- 

568 handle : driver.CUgraphConditionalHandle 

569 The handle to use for the if-else conditional. 

570 

571 Returns 

572 ------- 

573 graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`] 

574 A tuple of two new graph builders, one for the if branch and one for the else branch. 

575 

576 """ 

577 if _driver_ver < 12080: 1jklmnopq

578 raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else") 

579 if _py_major_minor < (12, 8): 1jklmnopq

580 raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else") 

581 node_params = driver.CUgraphNodeParams() 1jklmnopq

582 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL 1jklmnopq

583 node_params.conditional.handle = handle 1jklmnopq

584 node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF 1jklmnopq

585 node_params.conditional.size = 2 1jklmnopq

586 node_params.conditional.ctx = self._get_conditional_context() 1jklmnopq

587 return self._cond_with_params(node_params) 1jklmnopq

588 

589 def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]: 

590 """Adds a switch condition branch and returns new graph builders for all cases. 

591 

592 The resulting switch graph will execute the branch that matches the 

593 case index of the conditional handle at runtime. If no match is found, no branch 

594 will be executed. 

595 

596 The new builders inherit work dependencies from the original builder. 

597 

598 Parameters 

599 ---------- 

600 handle : driver.CUgraphConditionalHandle 

601 The handle to use for the switch conditional. 

602 count : int 

603 The number of cases to add to the switch conditional. 

604 

605 Returns 

606 ------- 

607 graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...] 

608 A tuple of new graph builders, one for each branch. 

609 

610 """ 

611 if _driver_ver < 12080: 1wrstu

612 raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch") 

613 if _py_major_minor < (12, 8): 1wrstu

614 raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch") 

615 node_params = driver.CUgraphNodeParams() 1wrstu

616 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL 1wrstu

617 node_params.conditional.handle = handle 1wrstu

618 node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH 1wrstu

619 node_params.conditional.size = count 1wrstu

620 node_params.conditional.ctx = self._get_conditional_context() 1wrstu

621 return self._cond_with_params(node_params) 1wrstu

622 

def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
    """Add a conditional while-loop node and return the builder for its body.

    At runtime the loop body is executed repeatedly for as long as the
    conditional handle evaluates to true.

    The returned builder inherits work dependencies from this builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The conditional handle that controls the loop.

    Returns
    -------
    graph_builder : :obj:`~_graph.GraphBuilder`
        The graph builder for the newly created while loop body.

    Raises
    ------
    RuntimeError
        If the driver version is below 12030 or the binding version is
        below (12, 3).

    """
    # Both the driver and the Python bindings must be recent enough.
    if _driver_ver < 12030:
        raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop")
    if _py_major_minor < (12, 3):
        raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop")

    # Describe a conditional node of WHILE type with a single body graph.
    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = handle
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    params.conditional.size = 1
    params.conditional.ctx = self._get_conditional_context()

    # size == 1 was requested above; hand back the first (only) body builder.
    body_builders = self._cond_with_params(params)
    return body_builders[0]

653 

def close(self):
    """Destroy the graph builder.

    If the builder owns its stream, the stream is closed as well. A
    borrowed stream object only has its reference released.

    """
    # All teardown is delegated to the finalizer-state object.
    finalizer_state = self._mnff
    finalizer_state.close()

662 

def add_child(self, child_graph: GraphBuilder):
    """Adds the child :obj:`~_graph.GraphBuilder` builder into self.

    The child graph builder will be added as a child node to the parent graph builder.

    Parameters
    ----------
    child_graph : :obj:`~_graph.GraphBuilder`
        The child graph builder. Must have finished building.

    Raises
    ------
    NotImplementedError
        If the driver version is below 12000 or the binding version is below (12, 0).
    ValueError
        If ``child_graph`` has not finished building, or if this builder is
        not currently being built.
    """
    if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
        # Note the trailing space: the two literals are concatenated into one message.
        raise NotImplementedError(
            "Launching child graphs is not implemented for versions older than CUDA 12. "
            f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}"
        )

    if not child_graph._building_ended:
        raise ValueError("Child graph has not finished building.")

    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    stream_handle = self._mnff.stream.handle
    # Everything returned between the graph and the dependency count is kept
    # as a list so the same code works regardless of how many such values the
    # installed binding returns.
    _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return(
        driver.cuStreamGetCaptureInfo(stream_handle)
    )

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    deps_info_trimmed = deps_info_out[:num_dependencies_out]
    # The new child-graph node becomes the single capture dependency; any
    # remaining positional slots (e.g. edge data) are padded with None.
    deps_info_update = [
        [
            handle_return(
                driver.cuGraphAddChildGraphNode(
                    graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph
                )
            )
        ]
    ] + [None] * (len(deps_info_out) - 1)
    # Replace the stream's capture dependencies so that subsequently captured
    # work depends on the child-graph node just added.
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            stream_handle,
            *deps_info_update,  # dependencies, edgeData
            1,
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

710 

711 

class Graph:
    """Represents an executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~_graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        # Holds the executable graph handle separately from the Graph object so
        # that the weakref finalizer registered below can destroy it.
        __slots__ = ("graph",)

        def __init__(self, graph_obj, graph):
            self.graph = graph
            # Call close() automatically once graph_obj is garbage collected.
            weakref.finalize(graph_obj, self.close)

        def close(self):
            # Clearing the handle after destruction makes repeated calls no-ops.
            if self.graph:
                handle_return(driver.cuGraphExecDestroy(self.graph))
                self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        # Internal constructor: bypasses __init__, which always raises.
        self = cls.__new__(cls)
        self._mnff = Graph._MembersNeededForFinalize(self, graph)
        return self

    def close(self):
        """Destroy the graph."""
        self._mnff.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        return self._mnff.graph

    def update(self, builder: GraphBuilder):
        """Update the graph using new build configuration from the builder.

        The topology of the provided builder must be identical to this graph.

        Parameters
        ----------
        builder : :obj:`~_graph.GraphBuilder`
            The builder to update the graph with.

        Raises
        ------
        ValueError
            If ``builder`` has not finished building.
        RuntimeError
            If the driver reports an update result other than success.

        """
        if not builder._building_ended:
            raise ValueError("Graph has not finished building.")

        # Update the graph with the new nodes from the builder
        exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
        if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
            # ``result`` is an attribute (it is compared against the enum above),
            # so it must not be called when building the message.
            raise RuntimeError(f"Failed to update graph: {exec_update_result.result}")

    def upload(self, stream: Stream):
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle))

    def launch(self, stream: Stream):
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph

        """
        handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle))