Coverage for cuda / core / _graph / __init__.py: 89.39%
330 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-25 01:07 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7import weakref
8from dataclasses import dataclass
9from typing import TYPE_CHECKING
11if TYPE_CHECKING:
12 from cuda.core._stream import Stream
14from cuda.core._utils.cuda_utils import (
15 driver,
16 get_binding_version,
17 handle_return,
18)
_inited = False
_driver_ver = None
# Declared at module scope (like ``_driver_ver``) so the name always exists;
# the real value is filled in by ``_lazy_init()``.
_py_major_minor = None


def _lazy_init():
    """Populate the module-level version caches on first use.

    Sets ``_py_major_minor`` from ``get_binding_version()`` and ``_driver_ver``
    from the driver's ``cuDriverGetVersion``. Once ``_inited`` is set, later
    calls return immediately without touching the driver.
    """
    global _inited, _py_major_minor, _driver_ver
    if _inited:
        return
    # binding availability depends on cuda-python version
    _py_major_minor = get_binding_version()
    _driver_ver = handle_return(driver.cuDriverGetVersion())
    _inited = True
@dataclass
class GraphDebugPrintOptions:
    """Options for :obj:`_graph.GraphBuilder.debug_dot_print()`.

    Every option is a boolean defaulting to ``False``. Each enabled option sets
    the matching ``CU_GRAPH_DEBUG_DOT_FLAGS_*`` bit in the flag mask handed to
    the driver's DOT printer.

    Attributes
    ----------
    verbose : bool
        Emit all debug data, as if every debug flag were enabled (Default to False)
    runtime_types : bool
        Describe nodes using CUDA Runtime structures (Default to False)
    kernel_node_params : bool
        Include kernel parameter values (Default to False)
    memcpy_node_params : bool
        Include memcpy parameter values (Default to False)
    memset_node_params : bool
        Include memset parameter values (Default to False)
    host_node_params : bool
        Include host node parameter values (Default to False)
    event_node_params : bool
        Include event node parameter values (Default to False)
    ext_semas_signal_node_params : bool
        Include external semaphore signal parameter values (Default to False)
    ext_semas_wait_node_params : bool
        Include external semaphore wait parameter values (Default to False)
    kernel_node_attributes : bool
        Include kernel node attributes (Default to False)
    handles : bool
        Include node handles and every kernel function handle (Default to False)
    mem_alloc_node_params : bool
        Include memory allocation parameter values (Default to False)
    mem_free_node_params : bool
        Include memory free parameter values (Default to False)
    batch_mem_op_node_params : bool
        Include batch memory-op parameter values (Default to False)
    extra_topo_info : bool
        Include edge numbering information (Default to False)
    conditional_node_params : bool
        Include conditional node parameter values (Default to False)

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Fold the enabled options into a driver flag mask (internal use)."""
        # (option attribute, CUgraphDebugDot_flags member), in the order the
        # bits are OR-ed into the mask.
        option_to_flag = (
            ("verbose", "CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE"),
            ("runtime_types", "CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES"),
            ("kernel_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS"),
            ("memcpy_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS"),
            ("memset_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS"),
            ("host_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS"),
            ("event_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS"),
            ("ext_semas_signal_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS"),
            ("ext_semas_wait_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS"),
            ("kernel_node_attributes", "CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES"),
            ("handles", "CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES"),
            ("mem_alloc_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS"),
            ("mem_free_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS"),
            ("batch_mem_op_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS"),
            ("extra_topo_info", "CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO"),
            ("conditional_node_params", "CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS"),
        )
        mask = 0
        for option, flag_name in option_to_flag:
            if not getattr(self, option):
                continue
            # The enum member is looked up only for enabled options, so a member
            # absent from an older binding matters only if that option is set.
            mask |= getattr(driver.CUgraphDebugDot_flags, flag_name)
        return mask
@dataclass
class GraphCompleteOptions:
    """Options for :obj:`_graph.GraphBuilder.complete()`.

    When the captured graph is instantiated, each enabled option becomes a
    graph-instantiation flag. ``upload_stream`` additionally supplies the
    stream used for the upload.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Free memory allocated inside the graph automatically before it is
        relaunched. (Default to False)
    upload_stream : Stream, optional
        Stream on which the graph is uploaded right after completion.
        (Default to None)
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms
        that support unified addressing, and cannot be combined with
        auto_free_on_launch. (Default to False)
    use_node_priority : bool, optional
        Run the graph with the per-node priority attributes instead of the
        priority of the stream it is launched into. (Default to False)

    """

    # Maps to CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH.
    auto_free_on_launch: bool = False
    # When set, enables CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD with this stream.
    upload_stream: Stream | None = None
    # Maps to CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH.
    device_launch: bool = False
    # Maps to CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY.
    use_node_priority: bool = False
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate the driver graph ``h_graph`` into an executable :obj:`Graph`.

    Parameters
    ----------
    h_graph
        Driver graph handle passed to ``cuGraphInstantiateWithParams``.
    options : :obj:`GraphCompleteOptions`, optional
        Instantiation options. They are translated into
        ``CUDA_GRAPH_INSTANTIATE_FLAG_*`` bits, and ``upload_stream`` also sets
        the upload stream handle.

    Returns
    -------
    Graph
        Executable graph wrapping the instantiated handle.

    Raises
    ------
    RuntimeError
        If the driver reports a non-success instantiation result.
    """
    # The result checks below read _py_major_minor, which only _lazy_init()
    # sets. Nothing here guarantees a GraphBuilder (whose _init calls it) has
    # been created first, so initialize explicitly; this is a no-op once done.
    _lazy_init()

    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        flags = 0
        if options.auto_free_on_launch:
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = flags

    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))
    if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    elif (
        # Guarded by binding version; presumably the enum member is absent
        # from bindings older than 12.8 — NOTE(review): confirm.
        _py_major_minor >= (12, 8)
        and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
    return graph
class GraphBuilder:
    """Represents a graph under construction.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Directly creating a :obj:`~_graph.GraphBuilder` is not supported due
    to ambiguity. New graph builders should instead be created through a
    :obj:`~_device.Device`, or a :obj:`~_stream.stream` object.

    """

    class _MembersNeededForFinalize:
        """Resources released when the owning :obj:`GraphBuilder` is collected.

        The state lives on this helper rather than on the builder so that the
        ``weakref.finalize`` callback (a bound method of this object) does not
        keep the builder itself alive.
        """

        __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

        def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
            self.stream = stream_obj
            self.is_stream_owner = is_stream_owner
            self.graph = None
            self.conditional_graph = conditional_graph
            self.is_join_required = is_join_required
            # Run close() once graph_builder_obj is garbage collected.
            weakref.finalize(graph_builder_obj, self.close)

        def close(self):
            """Release the stream and any captured graph held by the builder.

            An unfinished capture on a builder that does not require joining is
            ended first. The stream is closed only if the builder owns it, and
            a graph stored by ``end_building`` is destroyed. Every resource is
            reset to ``None`` afterwards, so calling this again is a no-op.
            """
            if self.stream:
                if not self.is_join_required:
                    capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0]
                    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                        # Note how this condition only occurs for the primary graph builder.
                        # This is because calling cuStreamEndCapture on streams that were split off
                        # of the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                        # Therefore, it is currently a requirement that users join all split graph
                        # builders before a graph builder can be cleanly destroyed.
                        captured = handle_return(driver.cuStreamEndCapture(self.stream.handle))
                        # A plain capture hands back a newly created graph that nothing else
                        # references, so destroy it rather than leak it. A conditional builder
                        # captured into the body graph of its parent's conditional node
                        # (see begin_building / _cond_with_params), which is not ours to destroy.
                        if captured and not self.conditional_graph:
                            handle_return(driver.cuGraphDestroy(captured))
                if self.is_stream_owner:
                    self.stream.close()
                self.stream = None
            if self.graph:
                handle_return(driver.cuGraphDestroy(self.graph))
            self.graph = None
            self.conditional_graph = None

    __slots__ = ("__weakref__", "_building_ended", "_mnff")

    def __init__(self):
        raise NotImplementedError(
            "directly creating a Graph object can be ambiguous. Please either "
            "call Device.create_graph_builder() or stream.create_graph_builder()"
        )

    @classmethod
    def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
        """Internal constructor that bypasses the public ``__init__`` guard.

        Parameters
        ----------
        stream
            Stream the builder captures work on.
        is_stream_owner : bool
            Whether ``close()`` should also close ``stream``.
        conditional_graph : optional
            Body graph of a conditional node to capture into, or ``None`` for a
            regular capture.
        is_join_required : bool
            Whether this builder must be joined back before building ends
            (set for builders produced by :meth:`split`).
        """
        self = cls.__new__(cls)
        _lazy_init()
        self._mnff = GraphBuilder._MembersNeededForFinalize(
            self, stream, is_stream_owner, conditional_graph, is_join_required
        )

        self._building_ended = False
        return self

    @property
    def stream(self) -> Stream:
        """Returns the stream associated with the graph builder."""
        return self._mnff.stream

    @property
    def is_join_required(self) -> bool:
        """Returns True if this graph builder must be joined before building is ended."""
        return self._mnff.is_join_required

    def begin_building(self, mode="relaxed") -> GraphBuilder:
        """Begins the building process.

        Build `mode` for controlling interaction with other API calls must be one of the following:

        - `global` : Prohibit potentially unsafe operations across all streams in the process.
        - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
        - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

        Parameters
        ----------
        mode : str, optional
            Build mode to control the interaction with other API calls that are potentially unsafe.
            Default set to use relaxed.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            ``self``, to allow call chaining.

        Raises
        ------
        RuntimeError
            If building has already ended on this builder.
        ValueError
            If ``mode`` is not one of the supported values.

        """
        if self._building_ended:
            raise RuntimeError("Cannot resume building after building has ended.")
        if mode not in ("global", "thread_local", "relaxed"):
            raise ValueError(f"Unsupported build mode: {mode}")
        if mode == "global":
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL
        elif mode == "thread_local":
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
        else:  # "relaxed": the only value left after the validation above
            capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED

        if self._mnff.conditional_graph:
            # Conditional builders capture into the pre-existing body graph.
            handle_return(
                driver.cuStreamBeginCaptureToGraph(
                    self._mnff.stream.handle,
                    self._mnff.conditional_graph,
                    None,  # dependencies
                    None,  # dependencyData
                    0,  # numDependencies
                    capture_mode,
                )
            )
        else:
            handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode))
        return self

    @property
    def is_building(self) -> bool:
        """Returns True if the graph builder is currently building.

        Raises
        ------
        RuntimeError
            If the capture has been invalidated.
        NotImplementedError
            If the driver reports an unrecognized capture status.
        """
        capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
        if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
            return False
        elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            return True
        elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
            raise RuntimeError(
                "Build process encountered an error and has been invalidated. Build process must now be ended."
            )
        else:
            raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")

    def end_building(self) -> GraphBuilder:
        """Ends the building process.

        The captured graph is stored on the builder (as the conditional body
        graph for conditional builders) and building cannot be resumed.

        Raises
        ------
        RuntimeError
            If the builder is not currently building.
        """
        if not self.is_building:
            raise RuntimeError("Graph builder is not building.")
        if self._mnff.conditional_graph:
            self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
        else:
            self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))

        # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
        # resume the build process after the first call to end_building()
        self._building_ended = True
        return self

    def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
        """Completes the graph builder and returns the built :obj:`~_graph.Graph` object.

        Parameters
        ----------
        options : :obj:`~_graph.GraphCompleteOptions`, optional
            Customizable dataclass for the graph builder completion options.

        Returns
        -------
        graph : :obj:`~_graph.Graph`
            The newly built graph.

        Raises
        ------
        RuntimeError
            If building has not ended, or if instantiation fails.

        """
        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")

        return _instantiate_graph(self._mnff.graph, options)

    def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
        """Generates a DOT debug file for the graph builder.

        Parameters
        ----------
        path : str, bytes or os.PathLike
            File path to use for writing debug DOT output
        options : :obj:`~_graph.GraphDebugPrintOptions`, optional
            Customizable dataclass for the debug print options.

        Raises
        ------
        RuntimeError
            If building has not ended.

        """
        import os

        if not self._building_ended:
            raise RuntimeError("Graph has not finished building.")
        flags = options._to_flags() if options else 0
        # The driver binding takes the path as a C string (bytes). Encode str /
        # PathLike inputs with the filesystem encoding; bytes pass through as-is.
        handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, os.fsencode(path), flags))

    def split(self, count: int) -> tuple[GraphBuilder, ...]:
        """Splits the original graph builder into multiple graph builders.

        The new builders inherit work dependencies from the original builder.
        The original builder is reused for the split and is returned first in the tuple.

        Parameters
        ----------
        count : int
            The number of graph builders to split the graph builder into.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of split graph builders. The first graph builder in the tuple
            is always the original graph builder.

        Raises
        ------
        ValueError
            If ``count`` is less than 2.

        """
        if count < 2:
            raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

        # Each new stream waits on this event so it inherits the work captured so far.
        event = self._mnff.stream.record()
        result = [self]
        try:
            for _ in range(count - 1):
                stream = self._mnff.stream.device.create_stream()
                stream.wait(event)
                result.append(
                    GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True)
                )
        finally:
            # Release the event even if creating or chaining a stream fails.
            event.close()
        return tuple(result)

    @staticmethod
    def join(*graph_builders) -> GraphBuilder:
        """Joins multiple graph builders into a single graph builder.

        The returned builder inherits work dependencies from the provided builders.
        The first builder that does not require joining is used as the root; if
        every builder requires joining, the first one is used. All other
        builders are closed.

        Parameters
        ----------
        *graph_builders : :obj:`~_graph.GraphBuilder`
            The graph builders to join.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly joined graph builder.

        Raises
        ------
        TypeError
            If any argument is not a :obj:`~_graph.GraphBuilder`.
        ValueError
            If fewer than two builders are given.

        """
        if any(not isinstance(builder, GraphBuilder) for builder in graph_builders):
            raise TypeError("All arguments must be GraphBuilder instances")
        if len(graph_builders) < 2:
            raise ValueError("Must join with at least two graph builders")

        # Discover the root builder others should join
        root_idx = 0
        for i, builder in enumerate(graph_builders):
            if not builder.is_join_required:
                root_idx = i
                break

        # Join all onto the root builder
        root_bdr = graph_builders[root_idx]
        for idx, builder in enumerate(graph_builders):
            if idx == root_idx:
                continue
            root_bdr.stream.wait(builder.stream)
            builder.close()

        return root_bdr

    def __cuda_stream__(self) -> tuple[int, int]:
        """Return an instance of a __cuda_stream__ protocol."""
        return self.stream.__cuda_stream__()

    def _get_conditional_context(self) -> driver.CUcontext:
        """Return the context handle of this builder's stream, used for conditional nodes."""
        return self._mnff.stream.context.handle

    def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle:
        """Creates a conditional handle for the graph builder.

        Parameters
        ----------
        default_value : int, optional
            The default value to assign to the conditional handle.

        Returns
        -------
        handle : driver.CUgraphConditionalHandle
            The newly created conditional handle.

        Raises
        ------
        RuntimeError
            If the driver or binding is older than 12.3, or the graph is not being built.

        """
        if _driver_ver < 12030:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles")
        if _py_major_minor < (12, 3):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles")
        if default_value is not None:
            flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
        else:
            default_value = 0
            flags = 0

        status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot create a conditional handle when graph is not being built")

        return handle_return(
            driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags)
        )

    def _conditional_node_params(self, handle, cond_type, size):
        """Build ``CUgraphNodeParams`` for a conditional node of ``cond_type`` with ``size`` bodies."""
        node_params = driver.CUgraphNodeParams()
        node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
        node_params.conditional.handle = handle
        node_params.conditional.type = cond_type
        node_params.conditional.size = size
        node_params.conditional.ctx = self._get_conditional_context()
        return node_params

    def _cond_with_params(self, node_params) -> GraphBuilder:
        """Add a conditional node to the active capture and return one builder per body graph."""
        # Get current capture info to ensure we're in a valid state
        status, _, graph, *deps_info, num_dependencies = handle_return(
            driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
        )
        if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot add conditional node when not actively capturing")

        # Add the conditional node to the graph. deps_info holds the dependency
        # list (plus edge data when the binding returns it) and is forwarded to
        # cuGraphAddNode as-is.
        deps_info_update = [
            [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))]
        ] + [None] * (len(deps_info) - 1)

        # Update the stream's capture dependencies so later work follows the new node
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                self._mnff.stream.handle,
                *deps_info_update,  # dependencies, edgeData
                1,  # numDependencies
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )

        # Create new graph builders for each condition
        return tuple(
            [
                GraphBuilder._init(
                    stream=self._mnff.stream.device.create_stream(),
                    is_stream_owner=True,
                    conditional_graph=node_params.conditional.phGraph_out[i],
                    is_join_required=False,
                )
                for i in range(node_params.conditional.size)
            ]
        )

    def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds an if condition branch and returns a new graph builder for it.

        The resulting if graph will only execute the branch if the conditional
        handle evaluates to true at runtime.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if conditional.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created conditional graph builder.

        """
        if _driver_ver < 12030:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if")
        if _py_major_minor < (12, 3):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if")
        node_params = self._conditional_node_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF, 1
        )
        return self._cond_with_params(node_params)[0]

    def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]:
        """Adds an if-else condition branch and returns new graph builders for both branches.

        The resulting if graph will execute the branch if the conditional handle
        evaluates to true at runtime, otherwise the else branch will execute.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the if-else conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`]
            A tuple of two new graph builders, one for the if branch and one for the else branch.

        """
        if _driver_ver < 12080:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else")
        if _py_major_minor < (12, 8):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else")
        # An IF node with two bodies: the second body is the else branch.
        node_params = self._conditional_node_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF, 2
        )
        return self._cond_with_params(node_params)

    def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]:
        """Adds a switch condition branch and returns new graph builders for all cases.

        The resulting switch graph will execute the branch that matches the
        case index of the conditional handle at runtime. If no match is found, no branch
        will be executed.

        The new builders inherit work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the switch conditional.
        count : int
            The number of cases to add to the switch conditional.

        Returns
        -------
        graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
            A tuple of new graph builders, one for each branch.

        """
        if _driver_ver < 12080:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch")
        if _py_major_minor < (12, 8):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch")
        node_params = self._conditional_node_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH, count
        )
        return self._cond_with_params(node_params)

    def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
        """Adds a while loop and returns a new graph builder for it.

        The resulting while loop graph will execute the branch repeatedly at runtime
        until the conditional handle evaluates to false.

        The new builder inherits work dependencies from the original builder.

        Parameters
        ----------
        handle : driver.CUgraphConditionalHandle
            The handle to use for the while loop.

        Returns
        -------
        graph_builder : :obj:`~_graph.GraphBuilder`
            The newly created while loop graph builder.

        """
        if _driver_ver < 12030:
            raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop")
        if _py_major_minor < (12, 3):
            raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop")
        node_params = self._conditional_node_params(
            handle, driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE, 1
        )
        return self._cond_with_params(node_params)[0]

    def close(self):
        """Destroy the graph builder.

        Closes the associated stream if we own it. Borrowed stream
        object will instead have their references released.

        """
        self._mnff.close()

    def add_child(self, child_graph: GraphBuilder):
        """Adds the child :obj:`~_graph.GraphBuilder` builder into self.

        The child graph builder will be added as a child node to the parent graph builder.

        Parameters
        ----------
        child_graph : :obj:`~_graph.GraphBuilder`
            The child graph builder. Must have finished building.

        Raises
        ------
        NotImplementedError
            If the driver or binding is older than CUDA 12.
        ValueError
            If the child has not finished building or the parent is not building.
        """
        if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
            raise NotImplementedError(
                "Launching child graphs is not implemented for versions older than CUDA 12. "
                f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}"
            )

        if not child_graph._building_ended:
            raise ValueError("Child graph has not finished building.")

        if not self.is_building:
            raise ValueError("Parent graph is not being built.")

        stream_handle = self._mnff.stream.handle
        _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return(
            driver.cuStreamGetCaptureInfo(stream_handle)
        )

        # deps_info_out holds the dependency list, followed by edge data when the
        # binding returns it. cuGraphAddChildGraphNode accepts only the dependency
        # list, so forward just the first entry. Slicing by the node count instead
        # would pass nothing when there are no dependencies yet, and would pass the
        # edge data too when there is more than one dependency.
        # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
        deps_info_trimmed = deps_info_out[:1]
        deps_info_update = [
            [
                handle_return(
                    driver.cuGraphAddChildGraphNode(
                        graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph
                    )
                )
            ]
        ] + [None] * (len(deps_info_out) - 1)
        # Make the new child node the sole capture dependency of the stream.
        handle_return(
            driver.cuStreamUpdateCaptureDependencies(
                stream_handle,
                *deps_info_update,  # dependencies, edgeData
                1,
                driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
            )
        )
class Graph:
    """Represents an executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~_graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        # Holds the CUgraphExec handle apart from the Graph object. The
        # finalizer callback (self.close) then does not reference the Graph
        # and does not keep it alive.
        __slots__ = ("graph",)

        def __init__(self, graph_obj, graph):
            self.graph = graph
            # Destroy the executable graph once graph_obj is garbage collected
            # (weakref.finalize also runs pending callbacks at interpreter exit).
            weakref.finalize(graph_obj, self.close)

        def close(self):
            # Idempotent: the handle is cleared after destruction, so an
            # explicit close() followed by the finalizer destroys it only once.
            if self.graph:
                handle_return(driver.cuGraphExecDestroy(self.graph))
            self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        """Internal constructor wrapping an existing ``CUgraphExec`` handle."""
        self = cls.__new__(cls)
        self._mnff = Graph._MembersNeededForFinalize(self, graph)
        return self

    def close(self):
        """Destroy the graph.

        Calling this more than once is harmless. After closing, :attr:`handle`
        returns ``None``.
        """
        self._mnff.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        return self._mnff.graph

    def update(self, builder: GraphBuilder):
        """Update the graph using new build configuration from the builder.

        The topology of the provided builder must be identical to this graph.

        Parameters
        ----------
        builder : :obj:`~_graph.GraphBuilder`
            The builder to update the graph with.

        Raises
        ------
        ValueError
            If ``builder`` has not finished building.
        RuntimeError
            If the driver reports an update result other than success.

        """
        if not builder._building_ended:
            raise ValueError("Graph has not finished building.")

        # Update the graph with the new nodes from the builder
        exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph))
        update_status = exec_update_result.result
        if update_status != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS:
            # ``result`` is a value (compared against an enum member above), not
            # a method. Format it directly; calling it would fail inside this
            # error path.
            raise RuntimeError(f"Failed to update graph: {update_status}")

    def upload(self, stream: Stream):
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle))

    def launch(self, stream: Stream):
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph

        """
        handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle))