Coverage for cuda / core / experimental / _graph.py: 88%

331 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-10 01:19 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7import weakref 

8from dataclasses import dataclass 

9from typing import TYPE_CHECKING 

10 

11if TYPE_CHECKING: 

12 from cuda.core.experimental._stream import Stream 

13from cuda.core.experimental._utils.cuda_utils import ( 

14 driver, 

15 get_binding_version, 

16 handle_return, 

17) 

18 

# Module-level cache filled in once by _lazy_init().
_inited = False
_driver_ver = None
# Initialised here as well so that every name _lazy_init() declares global
# exists at import time, matching _driver_ver.
_py_major_minor = None


def _lazy_init():
    """Query and cache the binding and driver versions on first use.

    Subsequent calls return immediately. After the first call,
    ``_py_major_minor`` holds the value returned by ``get_binding_version()``
    (compared against ``(major, minor)`` tuples elsewhere in this module) and
    ``_driver_ver`` holds the value reported by ``cuDriverGetVersion`` (compared
    against integers such as ``12000`` elsewhere in this module).
    """
    global _inited, _py_major_minor, _driver_ver
    if _inited:
        return

    # binding availability depends on cuda-python version
    _py_major_minor = get_binding_version()
    _driver_ver = handle_return(driver.cuDriverGetVersion())
    # Only flip the flag once both queries succeeded, so a failed attempt is retried.
    _inited = True

33 

34 

@dataclass
class GraphDebugPrintOptions:
    """Switches controlling what :obj:`_graph.GraphBuilder.debug_dot_print()` writes.

    Every switch is off unless explicitly enabled.

    Attributes
    ----------
    verbose : bool
        Emit all debug data, as though every other switch were enabled. (Defaults to False)
    runtime_types : bool
        Describe nodes using CUDA Runtime structures. (Defaults to False)
    kernel_node_params : bool
        Include kernel parameter values. (Defaults to False)
    memcpy_node_params : bool
        Include memcpy parameter values. (Defaults to False)
    memset_node_params : bool
        Include memset parameter values. (Defaults to False)
    host_node_params : bool
        Include host parameter values. (Defaults to False)
    event_node_params : bool
        Include event parameter values. (Defaults to False)
    ext_semas_signal_node_params : bool
        Include external semaphore signal parameter values. (Defaults to False)
    ext_semas_wait_node_params : bool
        Include external semaphore wait parameter values. (Defaults to False)
    kernel_node_attributes : bool
        Include kernel node attributes. (Defaults to False)
    handles : bool
        Include node handles and every kernel function handle. (Defaults to False)
    mem_alloc_node_params : bool
        Include memory alloc parameter values. (Defaults to False)
    mem_free_node_params : bool
        Include memory free parameter values. (Defaults to False)
    batch_mem_op_node_params : bool
        Include batch mem op parameter values. (Defaults to False)
    extra_topo_info : bool
        Include edge numbering information. (Defaults to False)
    conditional_node_params : bool
        Include conditional node parameter values. (Defaults to False)

    """

    # NOTE: field order is part of the positional constructor signature.
    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

92 

93 

@dataclass
class GraphCompleteOptions:
    """Settings accepted by :obj:`_graph.GraphBuilder.complete()`.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Free memory allocated in the graph automatically before it is relaunched.
        (Defaults to False)
    upload_stream : Stream, optional
        When given, the graph is uploaded on this stream right after completion.
        (Defaults to None)
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms that
        support unified addressing, and not together with auto_free_on_launch.
        (Defaults to False)
    use_node_priority : bool, optional
        Honour the per-node priority attributes instead of the priority of the
        stream the graph is launched into. (Defaults to False)

    """

    # NOTE: field order is part of the positional constructor signature.
    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False

118 

119 

class GraphBuilder:
    """Represents a graph under construction.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Directly creating a :obj:`~_graph.GraphBuilder` is not supported due
    to ambiguity. New graph builders should instead be created through a
    :obj:`~_device.Device`, or a :obj:`~_stream.Stream` object.

    """

    class _MembersNeededForFinalize:
        """Resources that must be released when the owning builder is closed or collected.

        The finalizer registered in ``__init__`` is bound to this object rather
        than to the builder, and invokes :meth:`close`.
        """

        __slots__ = ("stream", "is_stream_owner", "graph", "conditional_graph", "is_join_required")

        def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
            self.stream = stream_obj
            self.is_stream_owner = is_stream_owner
            # Populated by GraphBuilder.end_building() for non-conditional builders.
            self.graph = None
            self.conditional_graph = conditional_graph
            self.is_join_required = is_join_required
            weakref.finalize(graph_builder_obj, self.close)

        def close(self):
            """Release the stream and graph held by this object; safe to call repeatedly."""
            if self.stream:
                if not self.is_join_required:
                    capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0]
                    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                        # Note how this condition only occurs for the primary graph builder.
                        # This is because calling cuStreamEndCapture on streams that were split
                        # off of the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                        # Therefore, it is currently a requirement that users join all split graph
                        # builders before a graph builder can be cleanly destroyed.
                        handle_return(driver.cuStreamEndCapture(self.stream.handle))
                if self.is_stream_owner:
                    self.stream.close()
                self.stream = None
            if self.graph:
                handle_return(driver.cuGraphDestroy(self.graph))
                self.graph = None
            self.conditional_graph = None

    __slots__ = ("__weakref__", "_mnff", "_building_ended")

    def __init__(self):
        raise NotImplementedError(
            "directly creating a Graph object can be ambiguous. Please either "
            "call Device.create_graph_builder() or stream.create_graph_builder()"
        )

    @classmethod
    def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
        """Internal constructor used in place of ``__init__``.

        Parameters
        ----------
        stream : Stream
            The stream the builder captures on.
        is_stream_owner : bool
            Whether closing the builder should also close ``stream``.
        conditional_graph : optional
            Graph to capture into when this builder represents a conditional body.
        is_join_required : bool, optional
            Whether this builder was split off another builder and must be joined.
        """
        self = cls.__new__(cls)
        _lazy_init()
        self._mnff = GraphBuilder._MembersNeededForFinalize(
            self, stream, is_stream_owner, conditional_graph, is_join_required
        )

        self._building_ended = False
        return self

    @property
    def stream(self) -> Stream:
        """Returns the stream associated with the graph builder."""
        return self._mnff.stream

    @property
    def is_join_required(self) -> bool:
        """Returns True if this graph builder must be joined before building is ended."""
        return self._mnff.is_join_required

    def begin_building(self, mode="relaxed") -> GraphBuilder:
        """Begins the building process.

        Build `mode` for controlling interaction with other API calls must be one of the following:

        - `global` : Prohibit potentially unsafe operations across all streams in the process.
        - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
        - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

        Parameters
        ----------
        mode : str, optional
            Build mode to control the interaction with other API calls that are potentially unsafe.
            Default set to use relaxed.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            This builder, to allow call chaining.

        Raises
        ------
        RuntimeError
            If building has already been ended on this builder.
        ValueError
            If ``mode`` is not one of the supported build modes.

        """
        if self._building_ended:
            raise RuntimeError("Cannot resume building after building has ended.")
        # Validate before touching the driver so a bad mode never reaches the bindings.
        if mode not in ("global", "thread_local", "relaxed"):
            raise ValueError(f"Unsupported build mode: {mode}")
        capture_mode = {
            "global": driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL,
            "thread_local": driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
            "relaxed": driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED,
        }[mode]

        if self._mnff.conditional_graph:
            # Conditional bodies capture into the graph handed out by the conditional node.
            handle_return(
                driver.cuStreamBeginCaptureToGraph(
                    self._mnff.stream.handle,
                    self._mnff.conditional_graph,
                    None,  # dependencies
                    None,  # dependencyData
                    0,  # numDependencies
                    capture_mode,
                )
            )
        else:
            handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode))
        return self

    @property
    def is_building(self) -> bool:
        """Returns True if the graph builder is currently building.

        Raises
        ------
        RuntimeError
            If the capture on the builder's stream has been invalidated.
        NotImplementedError
            If the driver reports a capture status this method does not know.
        """
        capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
        if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
            return False
        elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            return True
        elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
            raise RuntimeError(
                "Build process encountered an error and has been invalidated. Build process must now be ended."
            )
        else:
            raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")

    def end_building(self) -> GraphBuilder:
        """Ends the building process.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            This builder, to allow call chaining.

        Raises
        ------
        RuntimeError
            If the builder is not currently building.
        """
        if not self.is_building:
            raise RuntimeError("Graph builder is not building.")
        if self._mnff.conditional_graph:
            self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
        else:
            self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))

        # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
        # resume the build process after the first call to end_building()
        self._building_ended = True
        return self

    def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
        """Completes the graph builder and returns the built :obj:`~_graph.Graph` object.

        Parameters
        ----------
        options : :obj:`~_graph.GraphCompleteOptions`, optional
            Customizable dataclass for the graph builder completion options.

        Returns
        -------
        graph : :obj:`~_graph.Graph`
            The newly built graph.

        Raises
        ------
        RuntimeError
            If building has not been ended, or if instantiation reports a failure.

        """
        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")

        if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
            # Pre-12.0 path: only the flag-based instantiate call is used, so only
            # auto_free_on_launch and use_node_priority are applied here.
            flags = 0
            if options:
                if options.auto_free_on_launch:
                    flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
                if options.use_node_priority:
                    flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
            return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags)))

        params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
        if options:
            flags = 0
            if options.auto_free_on_launch:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
            if options.upload_stream:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
                params.hUploadStream = options.upload_stream.handle
            if options.device_launch:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
            if options.use_node_priority:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
            params.flags = flags

        graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params)))
        if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
            # NOTE: Should never get here since the handle_return should have caught this case
            raise RuntimeError(
                "Instantiation failed for an unexpected reason which is described in the return value of the function."
            )
        elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
            raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
        elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
            raise RuntimeError(
                "Instantiation for device launch failed because the graph contained an unsupported operation."
            )
        elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
            raise RuntimeError(
                "Instantiation for device launch failed due to the nodes belonging to different contexts."
            )
        elif (
            # The enum member is only looked up when the bindings are new enough.
            _py_major_minor >= (12, 8)
            and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
        ):
            raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
        elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
            raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
        return graph

    def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
        """Generates a DOT debug file for the graph builder.

        Parameters
        ----------
        path : str
            File path to use for writing debug DOT output
        options : :obj:`~_graph.GraphDebugPrintOptions`, optional
            Customizable dataclass for the debug print options.

        Raises
        ------
        RuntimeError
            If building has not been ended on this builder.

        """
        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")
        flags = 0
        if options:
            # Each option enables the driver flag CU_GRAPH_DEBUG_DOT_FLAGS_<OPTION NAME UPPERCASED>.
            # The flag is looked up only for options that are set, so bindings lacking
            # a given flag fail only when that option is requested.
            for option_name in (
                "verbose",
                "runtime_types",
                "kernel_node_params",
                "memcpy_node_params",
                "memset_node_params",
                "host_node_params",
                "event_node_params",
                "ext_semas_signal_node_params",
                "ext_semas_wait_node_params",
                "kernel_node_attributes",
                "handles",
                "mem_alloc_node_params",
                "mem_free_node_params",
                "batch_mem_op_node_params",
                "extra_topo_info",
                "conditional_node_params",
            ):
                if getattr(options, option_name):
                    flags |= getattr(
                        driver.CUgraphDebugDot_flags, f"CU_GRAPH_DEBUG_DOT_FLAGS_{option_name.upper()}"
                    )

        handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags))

    def split(self, count: int) -> tuple[GraphBuilder, ...]:
        """Splits the original graph builder into multiple graph builders.

        The new builders inherit work dependencies from the original builder.
        The original builder is reused for the split and is returned first in the tuple.

        Parameters
        ----------
        count : int
            The number of graph builders to split the graph builder into.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of split graph builders. The first graph builder in the tuple
            is always the original graph builder.

        Raises
        ------
        ValueError
            If ``count`` is smaller than 2.

        """
        if count < 2:
            raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

        event = self._mnff.stream.record()
        builders = [self]
        try:
            for _ in range(count - 1):
                stream = self._mnff.stream.device.create_stream()
                # Make the new stream depend on the work recorded so far on the original.
                stream.wait(event)
                builders.append(
                    GraphBuilder._init(
                        stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True
                    )
                )
        finally:
            # Release the event even if creating one of the streams failed.
            event.close()
        return tuple(builders)

    @staticmethod
    def join(*graph_builders) -> GraphBuilder:
        """Joins multiple graph builders into a single graph builder.

        The returned builder inherits work dependencies from the provided builders.

        Parameters
        ----------
        *graph_builders : :obj:`~_graph.GraphBuilder`
            The graph builders to join.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly joined graph builder.

        Raises
        ------
        TypeError
            If any argument is not a :obj:`~_graph.GraphBuilder`.
        ValueError
            If fewer than two builders are given.

        """
        if any(not isinstance(builder, GraphBuilder) for builder in graph_builders):
            raise TypeError("All arguments must be GraphBuilder instances")
        if len(graph_builders) < 2:
            raise ValueError("Must join with at least two graph builders")

        # Discover the root builder others should join: the first one that does not
        # itself require joining, falling back to the first argument.
        root_idx = next(
            (idx for idx, builder in enumerate(graph_builders) if not builder.is_join_required),
            0,
        )

        # Join all onto the root builder
        root_bdr = graph_builders[root_idx]
        for idx, builder in enumerate(graph_builders):
            if idx == root_idx:
                continue
            root_bdr.stream.wait(builder.stream)
            builder.close()

        return root_bdr

    def __cuda_stream__(self) -> tuple[int, int]:
        """Return an instance of a __cuda_stream__ protocol."""
        return self.stream.__cuda_stream__()

    def _get_conditional_context(self) -> driver.CUcontext:
        """Return the context handle of the builder's stream, used for conditional nodes."""
        return self._mnff.stream.context._handle

    @staticmethod
    def _require_conditional_support(min_driver_ver, min_binding_ver, feature):
        """Raise RuntimeError unless both driver and bindings are new enough for ``feature``."""
        if _driver_ver < min_driver_ver:
            raise RuntimeError(f"Driver version {_driver_ver} does not support {feature}")
        if _py_major_minor < min_binding_ver:
            raise RuntimeError(f"Binding version {_py_major_minor} does not support {feature}")

    def _make_conditional_params(self, handle, cond_type, size):
        """Build the node parameters for a conditional node with ``size`` body graphs."""
        node_params = driver.CUgraphNodeParams()
        node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
        node_params.conditional.handle = handle
        node_params.conditional.type = cond_type
        node_params.conditional.size = size
        node_params.conditional.ctx = self._get_conditional_context()
        return node_params

    def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle:
        """Creates a conditional handle for the graph builder.

        Parameters
        ----------
        default_value : int, optional
            The default value to assign to the conditional handle.

        Returns
        -------
        handle : driver.CUgraphConditionalHandle
            The newly created conditional handle.

        Raises
        ------
        RuntimeError
            If the driver or bindings are too old, or the builder is not actively building.

        """
        self._require_conditional_support(12030, (12, 3), "conditional handles")
        if default_value is not None:
            flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
        else:
            default_value = 0
            flags = 0

        status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot create a conditional handle when graph is not being built")

        return handle_return(
            driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags)
        )

    def _cond_with_params(self, node_params) -> tuple[GraphBuilder, ...]:
        """Add a conditional node described by ``node_params`` and return one builder per body graph."""
        # Get current capture info to ensure we're in a valid state
        status, _, graph, *deps_info, num_dependencies = handle_return(
            driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
        )
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot add conditional node when not actively capturing")

        # Add the conditional node to the graph
        conditional_node = handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))
        # The new node becomes the single dependency; any further dependency-info
        # slots reported by the driver are passed as None.
        deps_info_update = [[conditional_node]] + [None] * (len(deps_info) - 1)

        # Update the stream's capture dependencies
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                self._mnff.stream.handle,
                *deps_info_update,  # dependencies, edgeData
                1,  # numDependencies
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )

        # Create new graph builders for each condition
        return tuple(
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[i],
                is_join_required=False,
            )
            for i in range(node_params.conditional.size)
        )

    def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds an if condition branch and returns a new graph builder for it.

        The resulting if graph will only execute the branch if the conditional
        handle evaluates to true at runtime.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if conditional.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created conditional graph builder.

        """
        self._require_conditional_support(12030, (12, 3), "conditional if")
        node_params = self._make_conditional_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF, 1
        )
        return self._cond_with_params(node_params)[0]

    def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]:
        """Adds an if-else condition branch and returns new graph builders for both branches.

        The resulting if graph will execute the branch if the conditional handle
        evaluates to true at runtime, otherwise the else branch will execute.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if-else conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`]
            A tuple of two new graph builders, one for the if branch and one for the else branch.

        """
        self._require_conditional_support(12080, (12, 8), "conditional if-else")
        # An if-else is an IF-type conditional node with two body graphs.
        node_params = self._make_conditional_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF, 2
        )
        return self._cond_with_params(node_params)

    def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]:
        """Adds a switch condition branch and returns new graph builders for all cases.

        The resulting switch graph will execute the branch that matches the
        case index of the conditional handle at runtime. If no match is found, no branch
        will be executed.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the switch conditional.
        count : int
            The number of cases to add to the switch conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of new graph builders, one for each branch.

        """
        self._require_conditional_support(12080, (12, 8), "conditional switch")
        node_params = self._make_conditional_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH, count
        )
        return self._cond_with_params(node_params)

    def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds a while loop and returns a new graph builder for it.

        The resulting while loop graph will execute the branch repeatedly at runtime
        until the conditional handle evaluates to false.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the while loop.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created while loop graph builder.

        """
        self._require_conditional_support(12030, (12, 3), "conditional while loop")
        node_params = self._make_conditional_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE, 1
        )
        return self._cond_with_params(node_params)[0]

    def close(self):
        """Destroy the graph builder.

        Closes the associated stream if we own it. Borrowed stream
        object will instead have their references released.

        """
        self._mnff.close()

    def add_child(self, child_graph: GraphBuilder):
        """Adds the child :obj:`~_graph.GraphBuilder` builder into self.

        The child graph builder will be added as a child node to the parent graph builder.

        Parameters
        ----------
        child_graph : :obj:`~_graph.GraphBuilder`
            The child graph builder. Must have finished building.

        Raises
        ------
        NotImplementedError
            If the driver or the bindings are older than CUDA 12.
        ValueError
            If the child has not finished building or this builder is not building.
        """
        if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
            raise NotImplementedError(
                "Launching child graphs is not implemented for versions older than CUDA 12. "
                f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}"
            )

        if not child_graph._building_ended:
            raise ValueError("Child graph has not finished building.")

        if not self.is_building:
            raise ValueError("Parent graph is not being built.")

        stream_handle = self._mnff.stream.handle
        _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return(
            driver.cuStreamGetCaptureInfo(stream_handle)
        )

        # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
        # for rationale
        deps_info_trimmed = deps_info_out[:num_dependencies_out]
        child_node = handle_return(
            driver.cuGraphAddChildGraphNode(
                graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph
            )
        )
        # The child node becomes the single capture dependency of the stream.
        deps_info_update = [[child_node]] + [None] * (len(deps_info_out) - 1)
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                stream_handle,
                *deps_info_update,  # dependencies, edgeData
                1,
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )

709 

710 

class Graph:
    """Represents an executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~_graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        """Holds the executable graph handle so it can be destroyed on close or collection.

        The finalizer registered in ``__init__`` is bound to this object rather
        than to the owning :obj:`~_graph.Graph`, and invokes :meth:`close`.
        """

        __slots__ = ("graph",)

        def __init__(self, graph_obj, graph):
            self.graph = graph
            weakref.finalize(graph_obj, self.close)

        def close(self):
            """Destroy the executable graph if it is still held; safe to call repeatedly."""
            if self.graph:
                handle_return(driver.cuGraphExecDestroy(self.graph))
                self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        """Internal constructor wrapping an already instantiated executable graph handle."""
        self = cls.__new__(cls)
        self._mnff = Graph._MembersNeededForFinalize(self, graph)
        return self

    def close(self):
        """Destroy the graph."""
        self._mnff.close()

    def update(self, builder: GraphBuilder):
        """Update the graph using new build configuration from the builder.

        The topology of the provided builder must be identical to this graph.

        Parameters
        ----------
        builder : :obj:`~_graph.GraphBuilder`
            The builder to update the graph with.

        Raises
        ------
        ValueError
            If ``builder`` has not finished building.
        RuntimeError
            If the driver reports that the update did not succeed.

        """
        if not builder._building_ended:
            raise ValueError("Graph has not finished building.")

        # Update the graph with the new nodes from the builder
        exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
        if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
            # ``result`` is an attribute (it is compared directly above), not a callable.
            raise RuntimeError(f"Failed to update graph: {exec_update_result.result}")

    def upload(self, stream: Stream):
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle))

    def launch(self, stream: Stream):
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph

        """
        handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle))