Coverage for cuda / core / experimental / _graph.py: 88%
331 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7import weakref
8from dataclasses import dataclass
9from typing import TYPE_CHECKING
11if TYPE_CHECKING:
12 from cuda.core.experimental._stream import Stream
13from cuda.core.experimental._utils.cuda_utils import (
14 driver,
15 get_binding_version,
16 handle_return,
17)
# Module-level cache filled in exactly once by _lazy_init(). All three names are
# declared here so that readers and linters can see every global the initializer writes.
_inited = False
_driver_ver = None
_py_major_minor = None


def _lazy_init():
    """Query and cache the binding and driver versions on first use.

    The first call stores the result of ``get_binding_version()`` in
    ``_py_major_minor`` and the result of ``cuDriverGetVersion()`` in
    ``_driver_ver``; later calls return immediately.
    """
    global _inited, _py_major_minor, _driver_ver
    if _inited:
        return

    # binding availability depends on cuda-python version
    _py_major_minor = get_binding_version()
    _driver_ver = handle_return(driver.cuDriverGetVersion())
    # Flip the flag last so a failure above leaves the module uninitialized.
    _inited = True
@dataclass
class GraphDebugPrintOptions:
    """Switches selecting what :obj:`_graph.GraphBuilder.debug_dot_print()` writes.

    Every attribute is an independent boolean that defaults to ``False``.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag is enabled (Default to False)
    runtime_types : bool
        Use CUDA Runtime structures for output (Default to False)
    kernel_node_params : bool
        Adds kernel parameter values to output (Default to False)
    memcpy_node_params : bool
        Adds memcpy parameter values to output (Default to False)
    memset_node_params : bool
        Adds memset parameter values to output (Default to False)
    host_node_params : bool
        Adds host parameter values to output (Default to False)
    event_node_params : bool
        Adds event parameter values to output (Default to False)
    ext_semas_signal_node_params : bool
        Adds external semaphore signal parameter values to output (Default to False)
    ext_semas_wait_node_params : bool
        Adds external semaphore wait parameter values to output (Default to False)
    kernel_node_attributes : bool
        Adds kernel node attributes to output (Default to False)
    handles : bool
        Adds node handles and every kernel function handle to output (Default to False)
    mem_alloc_node_params : bool
        Adds memory alloc parameter values to output (Default to False)
    mem_free_node_params : bool
        Adds memory free parameter values to output (Default to False)
    batch_mem_op_node_params : bool
        Adds batch mem op parameter values to output (Default to False)
    extra_topo_info : bool
        Adds edge numbering information (Default to False)
    conditional_node_params : bool
        Adds conditional node parameter values to output (Default to False)

    """

    # Field order is part of the positional constructor signature; do not reorder.
    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False
@dataclass
class GraphCompleteOptions:
    """Settings consumed by :obj:`_graph.GraphBuilder.complete()` when instantiating a graph.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Automatically free memory allocated in a graph before relaunching. (Default to False)
    upload_stream : Stream, optional
        Stream to use to automatically upload the graph after completion. (Default to None)
    device_launch : bool, optional
        Configure the graph to be launchable from the device. This flag can only
        be used on platforms which support unified addressing. This flag cannot be
        used in conjunction with auto_free_on_launch. (Default to False)
    use_node_priority : bool, optional
        Run the graph using the per-node priority attributes rather than the
        priority of the stream it is launched into. (Default to False)

    """

    # Field order is part of the positional constructor signature; do not reorder.
    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False
class GraphBuilder:
    """Represents a graph under construction.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Directly creating a :obj:`~_graph.GraphBuilder` is not supported due
    to ambiguity. New graph builders should instead be created through a
    :obj:`~_device.Device`, or a :obj:`~_stream.stream` object.

    """

    class _MembersNeededForFinalize:
        """Holds the resources that must be released when the builder goes away.

        The state lives in a separate object so that the ``weakref.finalize``
        callback registered below can run without referencing the builder itself.
        """

        __slots__ = ("stream", "is_stream_owner", "graph", "conditional_graph", "is_join_required")

        def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
            self.stream = stream_obj
            self.is_stream_owner = is_stream_owner
            self.graph = None
            self.conditional_graph = conditional_graph
            self.is_join_required = is_join_required
            weakref.finalize(graph_builder_obj, self.close)

        def close(self):
            """End any live capture, release the stream, and destroy the captured graph."""
            if self.stream:
                if not self.is_join_required:
                    capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0]
                    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                        # Note how this condition only occurs for the primary graph builder.
                        # This is because calling cuStreamEndCapture on streams that were split off of
                        # the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                        # Therefore, it is currently a requirement that users join all split graph
                        # builders before a graph builder can be cleanly destroyed.
                        handle_return(driver.cuStreamEndCapture(self.stream.handle))
                if self.is_stream_owner:
                    self.stream.close()
                # Borrowed streams are not closed; only our reference is dropped.
                self.stream = None
            if self.graph:
                handle_return(driver.cuGraphDestroy(self.graph))
                self.graph = None
            self.conditional_graph = None

    __slots__ = ("__weakref__", "_mnff", "_building_ended")

    def __init__(self):
        raise NotImplementedError(
            "directly creating a Graph object can be ambiguous. Please either "
            "call Device.create_graph_builder() or stream.create_graph_builder()"
        )

    @classmethod
    def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
        """Internal constructor that bypasses ``__init__``.

        Parameters
        ----------
        stream : Stream
            The stream used for building.
        is_stream_owner : bool
            Whether closing the builder should also close ``stream``.
        conditional_graph : optional
            Graph to build into when this builder is the body of a conditional node.
        is_join_required : bool, optional
            Whether this builder was split off another one and must be joined back.
        """
        self = cls.__new__(cls)
        _lazy_init()
        self._mnff = GraphBuilder._MembersNeededForFinalize(
            self, stream, is_stream_owner, conditional_graph, is_join_required
        )

        self._building_ended = False
        return self

    @property
    def stream(self) -> Stream:
        """Returns the stream associated with the graph builder."""
        return self._mnff.stream

    @property
    def is_join_required(self) -> bool:
        """Returns True if this graph builder must be joined before building is ended."""
        return self._mnff.is_join_required

    def begin_building(self, mode="relaxed") -> GraphBuilder:
        """Begins the building process.

        Build `mode` for controlling interaction with other API calls must be one of the following:

        - `global` : Prohibit potentially unsafe operations across all streams in the process.
        - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
        - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

        Parameters
        ----------
        mode : str, optional
            Build mode to control the interaction with other API calls that are potentially unsafe.
            Default set to use relaxed.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            This builder, to allow call chaining.

        Raises
        ------
        RuntimeError
            If building has already been ended on this builder.
        ValueError
            If `mode` is not one of the supported build modes.

        """
        if self._building_ended:
            raise RuntimeError("Cannot resume building after building has ended.")

        # A single chain both validates the mode and selects the driver enum value.
        if mode == "global":
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL
        elif mode == "thread_local":
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
        elif mode == "relaxed":
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED
        else:
            raise ValueError(f"Unsupported build mode: {mode}")

        if self._mnff.conditional_graph:
            # Conditional bodies capture into the graph that the conditional node provided.
            handle_return(
                driver.cuStreamBeginCaptureToGraph(
                    self._mnff.stream.handle,
                    self._mnff.conditional_graph,
                    None,  # dependencies
                    None,  # dependencyData
                    0,  # numDependencies
                    capture_mode,
                )
            )
        else:
            handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode))
        return self

    @property
    def is_building(self) -> bool:
        """Returns True if the graph builder is currently building.

        Raises
        ------
        RuntimeError
            If the capture on the underlying stream has been invalidated.
        NotImplementedError
            If the driver reports a capture status this code does not know about.

        """
        capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
        if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
            return False
        elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            return True
        elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
            raise RuntimeError(
                "Build process encountered an error and has been invalidated. Build process must now be ended."
            )
        else:
            raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")

    def end_building(self) -> GraphBuilder:
        """Ends the building process.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            This builder, to allow call chaining.

        Raises
        ------
        RuntimeError
            If the builder is not currently building.

        """
        if not self.is_building:
            raise RuntimeError("Graph builder is not building.")
        if self._mnff.conditional_graph:
            self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
        else:
            self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))

        # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
        # resume the build process after the first call to end_building()
        self._building_ended = True
        return self

    def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
        """Completes the graph builder and returns the built :obj:`~_graph.Graph` object.

        Parameters
        ----------
        options : :obj:`~_graph.GraphCompleteOptions`, optional
            Customizable dataclass for the graph builder completion options.

        Returns
        -------
        graph : :obj:`~_graph.Graph`
            The newly built graph.

        Raises
        ------
        RuntimeError
            If building has not been ended, or if instantiation reports a failure.

        """
        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")

        if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
            # Older driver/binding: fall back to the flags-only instantiate call.
            # Only the two flags below are forwarded on this path.
            flags = 0
            if options:
                if options.auto_free_on_launch:
                    flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
                if options.use_node_priority:
                    flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
            return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags)))

        params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
        if options:
            flags = 0
            if options.auto_free_on_launch:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
            if options.upload_stream:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
                params.hUploadStream = options.upload_stream.handle
            if options.device_launch:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
            if options.use_node_priority:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
            params.flags = flags

        graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params)))
        if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
            # NOTE: Should never get here since the handle_return should have caught this case
            raise RuntimeError(
                "Instantiation failed for an unexpected reason which is described in the return value of the function."
            )
        elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
            raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
        elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
            raise RuntimeError(
                "Instantiation for device launch failed because the graph contained an unsupported operation."
            )
        elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
            raise RuntimeError(
                "Instantiation for device launch failed due to the nodes belonging to different contexts."
            )
        elif (
            # The version test comes first so the enum member is only looked up on
            # bindings that are new enough.
            _py_major_minor >= (12, 8)
            and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
        ):
            raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
        elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
            raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
        return graph

    def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
        """Generates a DOT debug file for the graph builder.

        Parameters
        ----------
        path : str
            File path to use for writing debug DOT output
        options : :obj:`~_graph.GraphDebugPrintOptions`, optional
            Customizable dataclass for the debug print options.

        Raises
        ------
        RuntimeError
            If building has not been ended.

        """
        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")
        flags = 0
        if options:
            # Each enabled option ORs in the driver flag of the same name.
            if options.verbose:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE
            if options.runtime_types:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES
            if options.kernel_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS
            if options.memcpy_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS
            if options.memset_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS
            if options.host_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS
            if options.event_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS
            if options.ext_semas_signal_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS
            if options.ext_semas_wait_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS
            if options.kernel_node_attributes:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES
            if options.handles:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES
            if options.mem_alloc_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS
            if options.mem_free_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS
            if options.batch_mem_op_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS
            if options.extra_topo_info:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO
            if options.conditional_node_params:
                flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS

        handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags))

    def split(self, count: int) -> tuple[GraphBuilder, ...]:
        """Splits the original graph builder into multiple graph builders.

        The new builders inherit work dependencies from the original builder.
        The original builder is reused for the split and is returned first in the tuple.

        Parameters
        ----------
        count : int
            The number of graph builders to split the graph builder into.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of split graph builders. The first graph builder in the tuple
            is always the original graph builder.

        Raises
        ------
        ValueError
            If `count` is smaller than 2.

        """
        if count < 2:
            raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

        # Every new stream waits on this event, which is how it picks up the
        # dependencies recorded so far on the original stream.
        event = self._mnff.stream.record()
        result = [self]
        try:
            for _ in range(count - 1):
                stream = self._mnff.stream.device.create_stream()
                stream.wait(event)
                result.append(
                    GraphBuilder._init(
                        stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True
                    )
                )
        finally:
            # Release the event even if creating one of the streams failed.
            event.close()
        # Return a tuple, as the annotation and the docstring promise.
        return tuple(result)

    @staticmethod
    def join(*graph_builders) -> GraphBuilder:
        """Joins multiple graph builders into a single graph builder.

        The returned builder inherits work dependencies from the provided builders.

        Parameters
        ----------
        *graph_builders : :obj:`~_graph.GraphBuilder`
            The graph builders to join.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly joined graph builder.

        Raises
        ------
        TypeError
            If any argument is not a :obj:`~_graph.GraphBuilder`.
        ValueError
            If fewer than two graph builders are given.

        """
        if any(not isinstance(builder, GraphBuilder) for builder in graph_builders):
            raise TypeError("All arguments must be GraphBuilder instances")
        if len(graph_builders) < 2:
            raise ValueError("Must join with at least two graph builders")

        # Discover the root builder others should join: the first one that does not
        # itself require a join, falling back to the first argument.
        root_idx = 0
        for i, builder in enumerate(graph_builders):
            if not builder.is_join_required:
                root_idx = i
                break

        # Join all onto the root builder
        root_bdr = graph_builders[root_idx]
        for idx, builder in enumerate(graph_builders):
            if idx == root_idx:
                continue
            root_bdr.stream.wait(builder.stream)
            builder.close()

        return root_bdr

    def __cuda_stream__(self) -> tuple[int, int]:
        """Return an instance of a __cuda_stream__ protocol."""
        return self.stream.__cuda_stream__()

    def _get_conditional_context(self) -> driver.CUcontext:
        """Return the context handle of the builder's stream, used for conditional nodes."""
        return self._mnff.stream.context._handle

    def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle:
        """Creates a conditional handle for the graph builder.

        Parameters
        ----------
        default_value : int, optional
            The default value to assign to the conditional handle.

        Returns
        -------
        handle : driver.CUgraphConditionalHandle
            The newly created conditional handle.

        Raises
        ------
        RuntimeError
            If the driver or binding version is too old, or if the builder is not
            currently building.

        """
        if _driver_ver < 12030:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles")
        if _py_major_minor < (12, 3):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles")
        if default_value is not None:
            flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
        else:
            # No default requested: pass 0 with no flag set.
            default_value = 0
            flags = 0

        status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot create a conditional handle when graph is not being built")

        return handle_return(
            driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags)
        )

    def _cond_with_params(self, node_params) -> tuple[GraphBuilder, ...]:
        """Add a conditional node described by `node_params` and return its body builders.

        Returns one new graph builder per body graph of the conditional node
        (``node_params.conditional.size`` of them), each owning a freshly created stream.

        Raises
        ------
        RuntimeError
            If the builder is not actively capturing.
        """
        # Get current capture info to ensure we're in a valid state
        status, _, graph, *deps_info, num_dependencies = handle_return(
            driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
        )
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot add conditional node when not actively capturing")

        # Add the conditional node to the graph. The new node becomes the only
        # dependency; any remaining capture-info slots are passed as None.
        deps_info_update = [
            [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))]
        ] + [None] * (len(deps_info) - 1)

        # Update the stream's capture dependencies
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                self._mnff.stream.handle,
                *deps_info_update,  # dependencies, edgeData
                1,  # numDependencies
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )

        # Create new graph builders for each condition
        return tuple(
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[i],
                is_join_required=False,
            )
            for i in range(node_params.conditional.size)
        )

    def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds an if condition branch and returns a new graph builder for it.

        The resulting if graph will only execute the branch if the conditional
        handle evaluates to true at runtime.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if conditional.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created conditional graph builder.

        Raises
        ------
        RuntimeError
            If the driver or binding version is older than 12.3.

        """
        if _driver_ver < 12030:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if")
        if _py_major_minor < (12, 3):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if")
        node_params = driver.CUgraphNodeParams()
        node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
        node_params.conditional.handle = handle
        node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
        node_params.conditional.size = 1
        node_params.conditional.ctx = self._get_conditional_context()
        return self._cond_with_params(node_params)[0]

    def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]:
        """Adds an if-else condition branch and returns new graph builders for both branches.

        The resulting if graph will execute the branch if the conditional handle
        evaluates to true at runtime, otherwise the else branch will execute.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if-else conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`]
            A tuple of two new graph builders, one for the if branch and one for the else branch.

        Raises
        ------
        RuntimeError
            If the driver or binding version is older than 12.8.

        """
        if _driver_ver < 12080:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else")
        if _py_major_minor < (12, 8):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else")
        node_params = driver.CUgraphNodeParams()
        node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
        node_params.conditional.handle = handle
        # Same node type as if_cond(); requesting two body graphs adds the else branch.
        node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
        node_params.conditional.size = 2
        node_params.conditional.ctx = self._get_conditional_context()
        return self._cond_with_params(node_params)

    def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]:
        """Adds a switch condition branch and returns new graph builders for all cases.

        The resulting switch graph will execute the branch that matches the
        case index of the conditional handle at runtime. If no match is found, no branch
        will be executed.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the switch conditional.
        count : int
            The number of cases to add to the switch conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of new graph builders, one for each branch.

        Raises
        ------
        RuntimeError
            If the driver or binding version is older than 12.8.

        """
        if _driver_ver < 12080:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch")
        if _py_major_minor < (12, 8):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch")
        node_params = driver.CUgraphNodeParams()
        node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
        node_params.conditional.handle = handle
        node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
        node_params.conditional.size = count
        node_params.conditional.ctx = self._get_conditional_context()
        return self._cond_with_params(node_params)

    def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds a while loop and returns a new graph builder for it.

        The resulting while loop graph will execute the branch repeatedly at runtime
        until the conditional handle evaluates to false.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the while loop.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created while loop graph builder.

        Raises
        ------
        RuntimeError
            If the driver or binding version is older than 12.3.

        """
        if _driver_ver < 12030:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop")
        if _py_major_minor < (12, 3):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop")
        node_params = driver.CUgraphNodeParams()
        node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
        node_params.conditional.handle = handle
        node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
        node_params.conditional.size = 1
        node_params.conditional.ctx = self._get_conditional_context()
        return self._cond_with_params(node_params)[0]

    def close(self):
        """Destroy the graph builder.

        Closes the associated stream if we own it. Borrowed stream
        object will instead have their references released.

        """
        self._mnff.close()

    def add_child(self, child_graph: GraphBuilder):
        """Adds the child :obj:`~_graph.GraphBuilder` builder into self.

        The child graph builder will be added as a child node to the parent graph builder.

        Parameters
        ----------
        child_graph : :obj:`~_graph.GraphBuilder`
            The child graph builder. Must have finished building.

        Raises
        ------
        NotImplementedError
            If the driver or binding version is older than CUDA 12.
        ValueError
            If the child has not finished building, or if this builder is not building.
        """
        if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
            raise NotImplementedError(
                "Launching child graphs is not implemented for versions older than CUDA 12. "
                f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}"
            )

        if not child_graph._building_ended:
            raise ValueError("Child graph has not finished building.")

        if not self.is_building:
            raise ValueError("Parent graph is not being built.")

        stream_handle = self._mnff.stream.handle
        _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return(
            driver.cuStreamGetCaptureInfo(stream_handle)
        )

        # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
        # for rationale
        # NOTE(review): this slices the list of trailing capture-info entries by the
        # dependency count; how that behaves when the count is 0 or greater than 1 is
        # not evident from this file -- confirm against the linked discussion.
        deps_info_trimmed = deps_info_out[:num_dependencies_out]
        deps_info_update = [
            [
                handle_return(
                    driver.cuGraphAddChildGraphNode(
                        graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph
                    )
                )
            ]
        ] + [None] * (len(deps_info_out) - 1)
        # Make the new child-graph node the stream's only capture dependency.
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                stream_handle,
                *deps_info_update,  # dependencies, edgeData
                1,
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )
class Graph:
    """Represents an executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~_graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        """Holds the executable graph handle so it can be destroyed at finalization.

        The state lives in a separate object so that the ``weakref.finalize``
        callback registered below can run without referencing the Graph itself.
        """

        __slots__ = ("graph",)

        def __init__(self, graph_obj, graph):
            self.graph = graph
            weakref.finalize(graph_obj, self.close)

        def close(self):
            """Destroy the executable graph if it has not been destroyed already."""
            if self.graph:
                handle_return(driver.cuGraphExecDestroy(self.graph))
                self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        """Internal constructor wrapping an already instantiated executable graph handle."""
        self = cls.__new__(cls)
        self._mnff = Graph._MembersNeededForFinalize(self, graph)
        return self

    def close(self):
        """Destroy the graph."""
        self._mnff.close()

    def update(self, builder: GraphBuilder):
        """Update the graph using new build configuration from the builder.

        The topology of the provided builder must be identical to this graph.

        Parameters
        ----------
        builder : :obj:`~_graph.GraphBuilder`
            The builder to update the graph with.

        Raises
        ------
        ValueError
            If the builder has not finished building.
        RuntimeError
            If the driver reports that the update did not succeed.

        """
        if not builder._building_ended:
            raise ValueError("Graph has not finished building.")

        # Update the graph with the new nodes from the builder
        exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
        if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
            # ``result`` is an attribute (it is compared as a value just above), so it
            # must be formatted directly rather than called.
            raise RuntimeError(f"Failed to update graph: {exec_update_result.result}")

    def upload(self, stream: Stream):
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle))

    def launch(self, stream: Stream):
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph

        """
        handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle))