Coverage for cuda / core / _graph.py: 87.43%
334 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-08 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7import weakref
8from dataclasses import dataclass
9from typing import TYPE_CHECKING
11if TYPE_CHECKING:
12 from cuda.core._stream import Stream
14from cuda.core._utils.cuda_utils import (
15 driver,
16 get_binding_version,
17 handle_return,
18)
# Module-level state populated once by ``_lazy_init()``.
_inited = False
_driver_ver = None
# Given a default here so that the name always exists at module scope,
# mirroring ``_driver_ver``.
_py_major_minor = None


def _lazy_init():
    """Query and cache the binding and driver versions on first use.

    The first call stores the result of ``get_binding_version()`` in
    ``_py_major_minor`` and the result of ``cuDriverGetVersion()`` in
    ``_driver_ver``, then sets ``_inited``. Later calls return immediately.
    """
    global _inited, _py_major_minor, _driver_ver
    if _inited:
        return

    # binding availability depends on cuda-python version
    _py_major_minor = get_binding_version()
    _driver_ver = handle_return(driver.cuDriverGetVersion())
    _inited = True
@dataclass
class GraphDebugPrintOptions:
    """Switches selecting what :obj:`_graph.GraphBuilder.debug_dot_print()` writes.

    Every switch is off unless set explicitly.

    Attributes
    ----------
    verbose : bool
        Emit all debug data, as if every other switch were enabled. (Default to False)
    runtime_types : bool
        Emit CUDA Runtime structures in the output. (Default to False)
    kernel_node_params : bool
        Include kernel parameter values. (Default to False)
    memcpy_node_params : bool
        Include memcpy parameter values. (Default to False)
    memset_node_params : bool
        Include memset parameter values. (Default to False)
    host_node_params : bool
        Include host parameter values. (Default to False)
    event_node_params : bool
        Include event parameter values. (Default to False)
    ext_semas_signal_node_params : bool
        Include external semaphore signal parameter values. (Default to False)
    ext_semas_wait_node_params : bool
        Include external semaphore wait parameter values. (Default to False)
    kernel_node_attributes : bool
        Include kernel node attributes. (Default to False)
    handles : bool
        Include node handles and every kernel function handle. (Default to False)
    mem_alloc_node_params : bool
        Include memory alloc parameter values. (Default to False)
    mem_free_node_params : bool
        Include memory free parameter values. (Default to False)
    batch_mem_op_node_params : bool
        Include batch mem op parameter values. (Default to False)
    extra_topo_info : bool
        Include edge numbering information. (Default to False)
    conditional_node_params : bool
        Include conditional node parameter values. (Default to False)

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False
@dataclass
class GraphCompleteOptions:
    """Settings consumed by :obj:`_graph.GraphBuilder.complete()`

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Free memory allocated in the graph automatically before it is relaunched.
        (Default to False)
    upload_stream : Stream, optional
        When given, the graph is uploaded in this stream right after completion.
        (Default to None)
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms which
        support unified addressing, and not together with auto_free_on_launch.
        (Default to False)
    use_node_priority : bool, optional
        Honor the per-node priority attributes instead of the priority of the
        stream the graph is launched into. (Default to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False
class GraphBuilder:
    """Represents a graph under construction.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Directly creating a :obj:`~_graph.GraphBuilder` is not supported due
    to ambiguity. New graph builders should instead be created through a
    :obj:`~_device.Device`, or a :obj:`~_stream.Stream` object.

    """

    class _MembersNeededForFinalize:
        """Holds the resources that must be released when the builder goes away.

        The state lives on this separate object so that the finalizer registered
        with ``weakref.finalize`` does not reference the builder itself.
        """

        __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

        def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
            self.stream = stream_obj
            self.is_stream_owner = is_stream_owner
            self.graph = None
            self.conditional_graph = conditional_graph
            self.is_join_required = is_join_required
            weakref.finalize(graph_builder_obj, self.close)

        def close(self):
            """End any pending capture, then release the stream and the captured graph."""
            if self.stream:
                if not self.is_join_required:
                    capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0]
                    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                        # Note how this condition only occurs for the primary graph builder.
                        # This is because calling cuStreamEndCapture on streams that were split off
                        # of the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                        # Therefore, it is currently a requirement that users join all split graph
                        # builders before a graph builder can be cleanly destroyed.
                        # NOTE(review): the graph returned by cuStreamEndCapture is discarded
                        # here -- confirm whether it needs to be destroyed.
                        handle_return(driver.cuStreamEndCapture(self.stream.handle))
                if self.is_stream_owner:
                    self.stream.close()
            self.stream = None
            if self.graph:
                handle_return(driver.cuGraphDestroy(self.graph))
            self.graph = None
            self.conditional_graph = None

    __slots__ = ("__weakref__", "_building_ended", "_mnff")

    # (option attribute, ``driver.CUgraphDebugDot_flags`` member name) pairs used by
    # ``debug_dot_print``. Members are looked up by name only when the matching option
    # is enabled, so a flag is never resolved unless it is requested.
    _DEBUG_DOT_FLAG_NAMES = (
        ("verbose", "CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE"),
        ("runtime_types", "CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES"),
        ("kernel_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS"),
        ("memcpy_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS"),
        ("memset_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS"),
        ("host_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS"),
        ("event_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS"),
        ("ext_semas_signal_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS"),
        ("ext_semas_wait_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS"),
        ("kernel_node_attributes", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES"),
        ("handles", "CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES"),
        ("mem_alloc_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS"),
        ("mem_free_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS"),
        ("batch_mem_op_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS"),
        ("extra_topo_info", "CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO"),
        ("conditional_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS"),
    )

    def __init__(self):
        raise NotImplementedError(
            "directly creating a Graph object can be ambiguous. Please either "
            "call Device.create_graph_builder() or stream.create_graph_builder()"
        )

    @classmethod
    def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
        """Internal constructor; bypasses ``__init__`` and wires up the finalizer state."""
        self = cls.__new__(cls)
        _lazy_init()
        self._mnff = GraphBuilder._MembersNeededForFinalize(
            self, stream, is_stream_owner, conditional_graph, is_join_required
        )

        self._building_ended = False
        return self

    @property
    def stream(self) -> Stream:
        """Returns the stream associated with the graph builder."""
        return self._mnff.stream

    @property
    def is_join_required(self) -> bool:
        """Returns True if this graph builder must be joined before building is ended."""
        return self._mnff.is_join_required

    def begin_building(self, mode="relaxed") -> GraphBuilder:
        """Begins the building process.

        Build `mode` for controlling interaction with other API calls must be one of the following:

        - `global` : Prohibit potentially unsafe operations across all streams in the process.
        - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
        - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

        Parameters
        ----------
        mode : str, optional
            Build mode to control the interaction with other API calls that are potentially unsafe.
            Default set to use relaxed.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            This builder, to allow call chaining.

        Raises
        ------
        RuntimeError
            If building has already been ended on this builder.
        ValueError
            If `mode` is not one of the supported build modes.

        """
        if self._building_ended:
            raise RuntimeError("Cannot resume building after building has ended.")
        if mode not in ("global", "thread_local", "relaxed"):
            raise ValueError(f"Unsupported build mode: {mode}")

        if mode == "global":
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL
        elif mode == "thread_local":
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
        else:
            # Only "relaxed" remains after the validation above.
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED

        if self._mnff.conditional_graph:
            # Conditional builders capture into the graph attached to them at creation.
            handle_return(
                driver.cuStreamBeginCaptureToGraph(
                    self._mnff.stream.handle,
                    self._mnff.conditional_graph,
                    None,  # dependencies
                    None,  # dependencyData
                    0,  # numDependencies
                    capture_mode,
                )
            )
        else:
            handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode))
        return self

    @property
    def is_building(self) -> bool:
        """Returns True if the graph builder is currently building.

        Raises
        ------
        RuntimeError
            If the capture on the builder's stream has been invalidated.
        NotImplementedError
            If the driver reports a capture status this code does not know.

        """
        capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
        if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
            return False
        if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            return True
        if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
            raise RuntimeError(
                "Build process encountered an error and has been invalidated. Build process must now be ended."
            )
        raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")

    def end_building(self) -> GraphBuilder:
        """Ends the building process.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            This builder, to allow call chaining.

        Raises
        ------
        RuntimeError
            If the builder is not currently building.

        """
        if not self.is_building:
            raise RuntimeError("Graph builder is not building.")
        if self._mnff.conditional_graph:
            self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
        else:
            self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))

        # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
        # resume the build process after the first call to end_building()
        self._building_ended = True
        return self

    def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
        """Completes the graph builder and returns the built :obj:`~_graph.Graph` object.

        Parameters
        ----------
        options : :obj:`~_graph.GraphCompleteOptions`, optional
            Customizable dataclass for the graph builder completion options.

        Returns
        -------
        graph : :obj:`~_graph.Graph`
            The newly built graph.

        Raises
        ------
        RuntimeError
            If building has not ended, or if instantiation reports a failure.

        """
        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")

        if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
            # Pre-12.0 path: only the flags below are applied; ``upload_stream`` and
            # ``device_launch`` are not consulted on this path.
            flags = 0
            if options:
                if options.auto_free_on_launch:
                    flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
                if options.use_node_priority:
                    flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
            return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags)))

        params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
        if options:
            flags = 0
            if options.auto_free_on_launch:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
            if options.upload_stream:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
                params.hUploadStream = options.upload_stream.handle
            if options.device_launch:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
            if options.use_node_priority:
                flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
            params.flags = flags

        graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params)))

        result = params.result_out
        if result == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
            # NOTE: Should never get here since the handle_return should have caught this case
            raise RuntimeError(
                "Instantiation failed for an unexpected reason which is described in the return value of the function."
            )
        if result == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
            raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
        if result == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
            raise RuntimeError(
                "Instantiation for device launch failed because the graph contained an unsupported operation."
            )
        if result == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
            raise RuntimeError(
                "Instantiation for device launch failed due to the nodes belonging to different contexts."
            )
        # This enum member is only looked up with bindings >= 12.8.
        if (
            _py_major_minor >= (12, 8)
            and result == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
        ):
            raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
        if result != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
            raise RuntimeError(f"Graph instantiation failed with unexpected error code: {result}")
        return graph

    def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
        """Generates a DOT debug file for the graph builder.

        Parameters
        ----------
        path : str
            File path to use for writing debug DOT output
        options : :obj:`~_graph.GraphDebugPrintOptions`, optional
            Customizable dataclass for the debug print options.

        Raises
        ------
        RuntimeError
            If building has not ended.

        """
        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")

        flags = 0
        if options:
            for option_name, flag_name in self._DEBUG_DOT_FLAG_NAMES:
                if getattr(options, option_name):
                    flags |= getattr(driver.CUgraphDebugDot_flags, flag_name)

        handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags))

    def split(self, count: int) -> tuple[GraphBuilder, ...]:
        """Splits the original graph builder into multiple graph builders.

        The new builders inherit work dependencies from the original builder.
        The original builder is reused for the split and is returned first in the tuple.

        Parameters
        ----------
        count : int
            The number of graph builders to split the graph builder into.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of split graph builders. The first graph builder in the tuple
            is always the original graph builder.

        Raises
        ------
        ValueError
            If `count` is smaller than 2.

        """
        if count < 2:
            raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

        event = self._mnff.stream.record()
        try:
            builders = [self]
            for _ in range(count - 1):
                stream = self._mnff.stream.device.create_stream()
                # Each new stream waits on the event recorded on the original stream.
                stream.wait(event)
                builders.append(
                    GraphBuilder._init(
                        stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True
                    )
                )
        finally:
            # Release the event even if creating one of the streams fails.
            event.close()
        return tuple(builders)

    @staticmethod
    def join(*graph_builders) -> GraphBuilder:
        """Joins multiple graph builders into a single graph builder.

        The returned builder inherits work dependencies from the provided builders.

        Parameters
        ----------
        *graph_builders : :obj:`~_graph.GraphBuilder`
            The graph builders to join.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly joined graph builder.

        Raises
        ------
        TypeError
            If any argument is not a :obj:`~_graph.GraphBuilder`.
        ValueError
            If fewer than two graph builders are given.

        """
        if any(not isinstance(builder, GraphBuilder) for builder in graph_builders):
            raise TypeError("All arguments must be GraphBuilder instances")
        if len(graph_builders) < 2:
            raise ValueError("Must join with at least two graph builders")

        # Discover the root builder others should join: the first one that does not
        # itself require a join, falling back to the first argument.
        root_idx = next(
            (idx for idx, builder in enumerate(graph_builders) if not builder.is_join_required),
            0,
        )

        # Join all onto the root builder
        root_bdr = graph_builders[root_idx]
        for idx, builder in enumerate(graph_builders):
            if idx == root_idx:
                continue
            root_bdr.stream.wait(builder.stream)
            builder.close()

        return root_bdr

    def __cuda_stream__(self) -> tuple[int, int]:
        """Return an instance of a __cuda_stream__ protocol."""
        return self.stream.__cuda_stream__()

    def _get_conditional_context(self) -> driver.CUcontext:
        """Return the context handle of the builder's stream."""
        return self._mnff.stream.context.handle

    @staticmethod
    def _require_support(min_driver_ver, min_binding_ver, feature):
        """Raise ``RuntimeError`` if the driver or binding version is too old for `feature`."""
        if _driver_ver < min_driver_ver:
            raise RuntimeError(f"Driver version {_driver_ver} does not support {feature}")
        if _py_major_minor < min_binding_ver:
            raise RuntimeError(f"Binding version {_py_major_minor} does not support {feature}")

    def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle:
        """Creates a conditional handle for the graph builder.

        Parameters
        ----------
        default_value : int, optional
            The default value to assign to the conditional handle.

        Returns
        -------
        handle : driver.CUgraphConditionalHandle
            The newly created conditional handle.

        Raises
        ------
        RuntimeError
            If the driver or binding version is too old, or if the builder is not
            actively building.

        """
        self._require_support(12030, (12, 3), "conditional handles")
        if default_value is not None:
            flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
        else:
            default_value = 0
            flags = 0

        status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot create a conditional handle when graph is not being built")

        return handle_return(
            driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags)
        )

    def _cond_with_params(self, node_params) -> tuple[GraphBuilder, ...]:
        """Add a conditional node and return one new builder per conditional body graph."""
        # Get current capture info to ensure we're in a valid state
        status, _, graph, *deps_info, num_dependencies = handle_return(
            driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
        )
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot add conditional node when not actively capturing")

        # Add the conditional node to the graph. The new node becomes the only dependency;
        # any remaining dependency-info slots are passed as None.
        deps_info_update = [
            [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))]
        ] + [None] * (len(deps_info) - 1)

        # Update the stream's capture dependencies
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                self._mnff.stream.handle,
                *deps_info_update,  # dependencies, edgeData
                1,  # numDependencies
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )

        # Create new graph builders for each condition
        return tuple(
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[i],
                is_join_required=False,
            )
            for i in range(node_params.conditional.size)
        )

    def _add_conditional(self, handle, cond_type, size) -> tuple[GraphBuilder, ...]:
        """Fill in conditional node parameters and add the node to the graph being built."""
        node_params = driver.CUgraphNodeParams()
        node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
        node_params.conditional.handle = handle
        node_params.conditional.type = cond_type
        node_params.conditional.size = size
        node_params.conditional.ctx = self._get_conditional_context()
        return self._cond_with_params(node_params)

    def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds an if condition branch and returns a new graph builder for it.

        The resulting if graph will only execute the branch if the conditional
        handle evaluates to true at runtime.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if conditional.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created conditional graph builder.

        Raises
        ------
        RuntimeError
            If the driver or binding version is too old, or if the builder is not
            actively building.

        """
        self._require_support(12030, (12, 3), "conditional if")
        return self._add_conditional(handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF, 1)[0]

    def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]:
        """Adds an if-else condition branch and returns new graph builders for both branches.

        The resulting if graph will execute the branch if the conditional handle
        evaluates to true at runtime, otherwise the else branch will execute.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if-else conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`]
            A tuple of two new graph builders, one for the if branch and one for the else branch.

        Raises
        ------
        RuntimeError
            If the driver or binding version is too old, or if the builder is not
            actively building.

        """
        self._require_support(12080, (12, 8), "conditional if-else")
        # An if-else is an IF-type conditional node with two body graphs.
        return self._add_conditional(handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF, 2)

    def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]:
        """Adds a switch condition branch and returns new graph builders for all cases.

        The resulting switch graph will execute the branch that matches the
        case index of the conditional handle at runtime. If no match is found, no branch
        will be executed.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the switch conditional.
        count : int
            The number of cases to add to the switch conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of new graph builders, one for each branch.

        Raises
        ------
        RuntimeError
            If the driver or binding version is too old, or if the builder is not
            actively building.

        """
        self._require_support(12080, (12, 8), "conditional switch")
        return self._add_conditional(handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH, count)

    def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds a while loop and returns a new graph builder for it.

        The resulting while loop graph will execute the branch repeatedly at runtime
        until the conditional handle evaluates to false.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the while loop.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created while loop graph builder.

        Raises
        ------
        RuntimeError
            If the driver or binding version is too old, or if the builder is not
            actively building.

        """
        self._require_support(12030, (12, 3), "conditional while loop")
        return self._add_conditional(handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE, 1)[0]

    def close(self):
        """Destroy the graph builder.

        Closes the associated stream if we own it. Borrowed stream
        object will instead have their references released.

        """
        self._mnff.close()

    def add_child(self, child_graph: GraphBuilder):
        """Adds the child :obj:`~_graph.GraphBuilder` builder into self.

        The child graph builder will be added as a child node to the parent graph builder.

        Parameters
        ----------
        child_graph : :obj:`~_graph.GraphBuilder`
            The child graph builder. Must have finished building.

        Raises
        ------
        NotImplementedError
            If the driver or binding version is older than CUDA 12.
        ValueError
            If the child has not finished building, or if this builder is not building.

        """
        if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
            raise NotImplementedError(
                "Launching child graphs is not implemented for versions older than CUDA 12. "
                f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}"
            )

        if not child_graph._building_ended:
            raise ValueError("Child graph has not finished building.")

        if not self.is_building:
            raise ValueError("Parent graph is not being built.")

        stream_handle = self._mnff.stream.handle
        _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return(
            driver.cuStreamGetCaptureInfo(stream_handle)
        )

        # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
        # for rationale
        deps_info_trimmed = deps_info_out[:num_dependencies_out]
        child_node = handle_return(
            driver.cuGraphAddChildGraphNode(
                graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph
            )
        )
        # The child node becomes the stream's only capture dependency; any remaining
        # dependency-info slots are passed as None.
        deps_info_update = [[child_node]] + [None] * (len(deps_info_out) - 1)
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                stream_handle,
                *deps_info_update,  # dependencies, edgeData
                1,
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )
class Graph:
    """Represents an executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~_graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        """Holds the executable graph handle so it can be destroyed by a finalizer."""

        __slots__ = ("graph",)

        def __init__(self, graph_obj, graph):
            self.graph = graph
            weakref.finalize(graph_obj, self.close)

        def close(self):
            """Destroy the executable graph if one is still held."""
            if self.graph:
                handle_return(driver.cuGraphExecDestroy(self.graph))
            self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        """Internal constructor; wraps an instantiated graph handle."""
        self = cls.__new__(cls)
        self._mnff = Graph._MembersNeededForFinalize(self, graph)
        return self

    def close(self):
        """Destroy the graph."""
        self._mnff.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        return self._mnff.graph

    def update(self, builder: GraphBuilder):
        """Update the graph using new build configuration from the builder.

        The topology of the provided builder must be identical to this graph.

        Parameters
        ----------
        builder : :obj:`~_graph.GraphBuilder`
            The builder to update the graph with.

        Raises
        ------
        ValueError
            If the builder has not finished building.
        RuntimeError
            If the driver reports that the update did not succeed.

        """
        if not builder._building_ended:
            raise ValueError("Graph has not finished building.")

        # Update the graph with the new nodes from the builder
        exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
        if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
            # ``result`` is an attribute, as used in the comparison above; calling it
            # would raise TypeError and hide the actual failure.
            raise RuntimeError(f"Failed to update graph: {exec_update_result.result}")

    def upload(self, stream: Stream):
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle))

    def launch(self, stream: Stream):
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph

        """
        handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle))