Coverage for cuda / core / graph / _graph_builder.pyx: 88.28%
384 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-22 01:37 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5import weakref
6from dataclasses import dataclass
8from libc.stdint cimport intptr_t
10from cuda.bindings cimport cydriver
12from cuda.core.graph._graph_definition cimport GraphCondition
13from cuda.core.graph._utils cimport _attach_host_callback_to_graph
14from cuda.core._resource_handles cimport as_cu
15from cuda.core._stream cimport Stream
16from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
17from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
19from cuda.core._utils.cuda_utils import (
20 CUDAError,
21 driver,
22 handle_return,
23)
25__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions']
@dataclass
class GraphDebugPrintOptions:
    """Switches selecting what ``debug_dot_print()`` writes to the DOT file.

    Every attribute is a boolean that defaults to ``False``.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag is enabled.
    runtime_types : bool
        Use CUDA Runtime structures for output.
    kernel_node_params : bool
        Add kernel parameter values to the output.
    memcpy_node_params : bool
        Add memcpy parameter values to the output.
    memset_node_params : bool
        Add memset parameter values to the output.
    host_node_params : bool
        Add host parameter values to the output.
    event_node_params : bool
        Add event parameter values to the output.
    ext_semas_signal_node_params : bool
        Add external semaphore signal parameter values to the output.
    ext_semas_wait_node_params : bool
        Add external semaphore wait parameter values to the output.
    kernel_node_attributes : bool
        Add kernel node attributes to the output.
    handles : bool
        Add node handles and every kernel function handle to the output.
    mem_alloc_node_params : bool
        Add memory alloc parameter values to the output.
    mem_free_node_params : bool
        Add memory free parameter values to the output.
    batch_mem_op_node_params : bool
        Add batch mem op parameter values to the output.
    extra_topo_info : bool
        Add edge numbering information.
    conditional_node_params : bool
        Add conditional node parameter values to the output.

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Combine the enabled options into a driver debug-DOT flag mask (internal use).

        Returns
        -------
        int
            Bitwise OR of the ``CUgraphDebugDot_flags`` members whose option is
            set; ``0`` when no option is set.
        """
        # Each option maps to the enum member named
        # CU_GRAPH_DEBUG_DOT_FLAGS_<OPTION NAME IN UPPER CASE>.
        option_names = (
            "verbose",
            "runtime_types",
            "kernel_node_params",
            "memcpy_node_params",
            "memset_node_params",
            "host_node_params",
            "event_node_params",
            "ext_semas_signal_node_params",
            "ext_semas_wait_node_params",
            "kernel_node_attributes",
            "handles",
            "mem_alloc_node_params",
            "mem_free_node_params",
            "batch_mem_op_node_params",
            "extra_topo_info",
            "conditional_node_params",
        )
        mask = 0
        for option_name in option_names:
            if getattr(self, option_name):
                # Look the member up only for enabled options, so a disabled
                # option never touches the driver enum.
                member_name = "CU_GRAPH_DEBUG_DOT_FLAGS_" + option_name.upper()
                mask |= getattr(driver.CUgraphDebugDot_flags, member_name)
        return mask
@dataclass
class GraphCompleteOptions:
    """Options for graph instantiation.

    ``_instantiate_graph`` reads these fields and sets the matching
    ``CUgraphInstantiate_flags`` bit for every option that is truthy.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Automatically free memory allocated in a graph before relaunching.
        (Defaults to False)
    upload_stream : Stream, optional
        Stream to use to automatically upload the graph after completion.
        Its ``handle`` is stored in the instantiate parameters when set.
        (Defaults to None)
    device_launch : bool, optional
        Configure the graph to be launchable from the device. This flag can only
        be used on platforms which support unified addressing. This flag cannot be
        used in conjunction with auto_free_on_launch. (Defaults to False)
    use_node_priority : bool, optional
        Run the graph using the per-node priority attributes rather than the
        priority of the stream it is launched into. (Defaults to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate a driver graph handle into an executable ``Graph``.

    Parameters
    ----------
    h_graph
        Graph handle passed to ``cuGraphInstantiateWithParams``.
    options : GraphCompleteOptions, optional
        Instantiation options; when falsy the parameters are left at the
        values ``CUDA_GRAPH_INSTANTIATE_PARAMS()`` starts with.

    Returns
    -------
    Graph
        Wrapper built by ``Graph._init`` around the instantiated handle.

    Raises
    ------
    RuntimeError
        If ``result_out`` of the instantiate parameters is anything other
        than ``CUDA_GRAPH_INSTANTIATE_SUCCESS``.
    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        flag_enum = driver.CUgraphInstantiate_flags
        mask = 0
        if options.auto_free_on_launch:
            mask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            mask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            mask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            mask |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = mask

    exec_handle = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))
    graph = Graph._init(exec_handle)

    # Translate the detailed instantiation result into an exception. Every
    # branch raises, so plain guard clauses behave like an elif chain.
    outcome = params.result_out
    result_enum = driver.CUgraphInstantiateResult
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_ERROR:
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    # The CONDITIONAL_HANDLE_UNUSED member is only consulted when the binding
    # version is at least 12.8.
    if (
        cy_binding_version() >= (12, 8, 0)
        and outcome == result_enum.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    if outcome != result_enum.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {outcome}")
    return graph
188class GraphBuilder:
189 """A graph under construction by stream capture.
191 A graph groups a set of CUDA kernels and other CUDA operations together and executes
192 them with a specified dependency tree. It speeds up the workflow by combining the
193 driver activities associated with CUDA kernel launches and CUDA API calls.
195 Directly creating a :obj:`~graph.GraphBuilder` is not supported due
196 to ambiguity. New graph builders should instead be created through a
197 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object.
199 """
class _MembersNeededForFinalize:
    """Resources of a graph builder that ``close`` releases.

    Construction registers ``close`` as a ``weakref.finalize`` callback on the
    owning builder object.
    """

    __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

    def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
        self.stream, self.is_stream_owner = stream_obj, is_stream_owner
        # No graph handle yet; end_building() stores one later.
        self.graph = None
        self.conditional_graph, self.is_join_required = conditional_graph, is_join_required
        weakref.finalize(graph_builder_obj, self.close)

    def close(self):
        """Release the stream and graph held by this object; later calls find nothing to release."""
        stream = self.stream
        if stream:
            if not self.is_join_required:
                status = handle_return(driver.cuStreamGetCaptureInfo(stream.handle))[0]
                still_capturing = status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE
                if still_capturing:
                    # Note how this condition only occurs for the primary graph builder.
                    # This is because calling cuStreamEndCapture on streams that were split
                    # off of the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                    # Therefore, it is currently a requirement that users join all split graph
                    # builders before a graph builder can be cleanly destroyed.
                    handle_return(driver.cuStreamEndCapture(stream.handle))
            if self.is_stream_owner:
                stream.close()
            self.stream = None

        graph = self.graph
        if graph:
            handle_return(driver.cuGraphDestroy(graph))
            self.graph = None
        self.conditional_graph = None
231 __slots__ = ("__weakref__", "_building_ended", "_mnff")
def __init__(self):
    """Always raise: builders are created through a device or a stream, not directly.

    Raises
    ------
    NotImplementedError
        On every call.
    """
    message = (
        "directly creating a Graph object can be ambiguous. Please either "
        "call Device.create_graph_builder() or stream.create_graph_builder()"
    )
    raise NotImplementedError(message)
@classmethod
def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
    """Build a graph builder without going through ``__init__`` (internal use).

    Parameters
    ----------
    stream
        Stream the builder operates on.
    is_stream_owner : bool
        Whether closing the builder also closes ``stream``.
    conditional_graph : optional
        Graph handle to capture into, for conditional-branch builders.
    is_join_required : bool, optional
        Whether the builder must be joined before building is ended.

    Returns
    -------
    GraphBuilder
        The new builder, with ``_building_ended`` set to ``False``.
    """
    builder = cls.__new__(cls)
    builder._building_ended = False
    builder._mnff = GraphBuilder._MembersNeededForFinalize(
        builder,
        stream,
        is_stream_owner,
        conditional_graph,
        is_join_required,
    )
    return builder
@property
def stream(self) -> Stream:
    """The stream held by this graph builder."""
    members = self._mnff
    return members.stream
@property
def is_join_required(self) -> bool:
    """Whether this graph builder must be joined before building is ended."""
    members = self._mnff
    return members.is_join_required
def begin_building(self, mode="relaxed") -> GraphBuilder:
    """Begins the building process.

    Build `mode` for controlling interaction with other API calls must be one of the following:

    - `global` : Prohibit potentially unsafe operations across all streams in the process.
    - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
    - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

    Parameters
    ----------
    mode : str, optional
        Build mode to control the interaction with other API calls that are potentially unsafe.
        Default set to use relaxed.

    Returns
    -------
    GraphBuilder
        This builder, so calls can be chained.

    Raises
    ------
    RuntimeError
        If building has already been ended on this builder.
    ValueError
        If `mode` is not one of the supported build modes.

    """
    if self._building_ended:
        raise RuntimeError("Cannot resume building after building has ended.")

    # Single validation point: the final branch rejects every unsupported
    # mode before any driver call is made.
    if mode == "global":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL
    elif mode == "thread_local":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
    elif mode == "relaxed":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED
    else:
        raise ValueError(f"Unsupported build mode: {mode}")

    if self._mnff.conditional_graph:
        # Builders created for a conditional branch capture into the graph
        # handle they were given.
        handle_return(
            driver.cuStreamBeginCaptureToGraph(
                self._mnff.stream.handle,
                self._mnff.conditional_graph,
                None,  # dependencies
                None,  # dependencyData
                0,  # numDependencies
                capture_mode,
            )
        )
    else:
        handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode))
    return self
@property
def is_building(self) -> bool:
    """Whether the builder's stream is currently capturing.

    Raises
    ------
    RuntimeError
        If the capture status is ``CU_STREAM_CAPTURE_STATUS_INVALIDATED``.
    NotImplementedError
        If the driver reports a capture status this method does not handle.
    """
    status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
    status_enum = driver.CUstreamCaptureStatus
    if status == status_enum.CU_STREAM_CAPTURE_STATUS_NONE:
        return False
    if status == status_enum.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return True
    if status == status_enum.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. Build process must now be ended."
        )
    raise NotImplementedError(f"Unsupported capture status type received: {status}")
def end_building(self) -> GraphBuilder:
    """Ends the building process and stores the captured graph handle.

    Returns
    -------
    GraphBuilder
        This builder, so calls can be chained.

    Raises
    ------
    RuntimeError
        If the builder is not currently building.
    """
    if not self.is_building:
        raise RuntimeError("Graph builder is not building.")

    members = self._mnff
    # Decide where the handle goes before ending capture, then store it there.
    captures_into_conditional = bool(members.conditional_graph)
    captured = handle_return(driver.cuStreamEndCapture(self.stream.handle))
    if captures_into_conditional:
        members.conditional_graph = captured
    else:
        members.graph = captured

    # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
    # resume the build process after the first call to end_building()
    self._building_ended = True
    return self
def complete(self, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate the finished build into a :obj:`~graph.Graph`.

    Parameters
    ----------
    options : :obj:`~graph.GraphCompleteOptions`, optional
        Customizable dataclass for the graph builder completion options.

    Returns
    -------
    graph : :obj:`~graph.Graph`
        The newly built graph.

    Raises
    ------
    RuntimeError
        If building has not been ended yet.

    """
    if self._building_ended:
        return _instantiate_graph(self._mnff.graph, options)
    raise RuntimeError("Graph has not finished building.")
def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
    """Generates a DOT debug file for the graph builder.

    Parameters
    ----------
    path : str, bytes or os.PathLike
        File path to use for writing debug DOT output. Text and path-like
        values are encoded with :func:`os.fsencode`; bytes are passed through
        unchanged.
    options : :obj:`~graph.GraphDebugPrintOptions`, optional
        Customizable dataclass for the debug print options.

    Raises
    ------
    RuntimeError
        If building has not been ended yet.

    """
    import os

    if not self._building_ended:
        raise RuntimeError("Graph has not finished building.")
    flags = options._to_flags() if options else 0
    # NOTE(review): the driver binding is expected to take the path as a C
    # string, i.e. bytes on the Python side — confirm against cuda.bindings.
    # os.fsencode returns bytes input untouched, so bytes callers see no change.
    encoded_path = os.fsencode(path)
    handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, encoded_path, flags))
def split(self, count: int) -> tuple[GraphBuilder, ...]:
    """Splits the original graph builder into multiple graph builders.

    The new builders inherit work dependencies from the original builder.
    The original builder is reused for the split and is returned first in the tuple.

    Parameters
    ----------
    count : int
        The number of graph builders to split the graph builder into.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of split graph builders. The first graph builder in the tuple
        is always the original graph builder.

    Raises
    ------
    ValueError
        If ``count`` is smaller than 2.

    """
    if count < 2:
        raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

    origin_stream = self._mnff.stream
    # Every new stream waits on this event, recorded on the original stream.
    event = origin_stream.record()
    result = [self]
    try:
        for _ in range(count - 1):
            stream = origin_stream.device.create_stream()
            stream.wait(event)
            result.append(
                GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True)
            )
    finally:
        # Close the event even when creating or wiring a stream fails.
        event.close()
    return tuple(result)
@staticmethod
def join(*graph_builders) -> GraphBuilder:
    """Joins multiple graph builders into a single graph builder.

    The root is the first builder that does not require joining, or the first
    argument when every builder requires it. The root's stream waits on every
    other builder's stream, and each of those builders is then closed.

    Parameters
    ----------
    *graph_builders : :obj:`~graph.GraphBuilder`
        The graph builders to join.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The root builder that the others were joined onto.

    Raises
    ------
    TypeError
        If any argument is not a ``GraphBuilder``.
    ValueError
        If fewer than two builders are given.

    """
    if not all(isinstance(candidate, GraphBuilder) for candidate in graph_builders):
        raise TypeError("All arguments must be GraphBuilder instances")
    if len(graph_builders) < 2:
        raise ValueError("Must join with at least two graph builders")

    # Discover the root builder others should join
    root_idx = next(
        (pos for pos, candidate in enumerate(graph_builders) if not candidate.is_join_required),
        0,
    )
    root_bdr = graph_builders[root_idx]

    # Join all onto the root builder
    for pos, other in enumerate(graph_builders):
        if pos != root_idx:
            root_bdr.stream.wait(other.stream)
            other.close()

    return root_bdr
def __cuda_stream__(self) -> tuple[int, int]:
    """Forward the ``__cuda_stream__`` protocol call to the builder's stream."""
    underlying = self.stream
    return underlying.__cuda_stream__()
def _get_conditional_context(self) -> driver.CUcontext:
    """Return the handle of the context of this builder's stream (internal use)."""
    stream_context = self._mnff.stream.context
    return stream_context.handle
def create_condition(self, default_value=None) -> GraphCondition:
    """Create a condition variable for use with conditional nodes.

    The returned :class:`GraphCondition` object is passed to conditional-node
    builder methods (:meth:`if_then`, :meth:`if_else`, :meth:`while_loop`,
    :meth:`switch`). Its value is controlled at runtime by device code via
    ``cudaGraphSetConditional``.

    Parameters
    ----------
    default_value : int, optional
        The default value to assign to the condition. If None, no
        default is assigned.

    Returns
    -------
    GraphCondition
        A condition variable for controlling conditional execution.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3, or if the
        builder's stream is not actively capturing.
    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        ver_text = ".".join(map(str, driver_ver))
        raise RuntimeError(f"Driver version {ver_text} does not support conditional handles")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        ver_text = ".".join(map(str, binding_ver))
        raise RuntimeError(f"Binding version {ver_text} does not support conditional handles")

    if default_value is None:
        # No default requested: pass zero for both the value and the flags.
        default_value = 0
        flags = 0
    else:
        flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT

    # Only the capture status (first item) and the graph (third item) are used.
    status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot create a condition when graph is not being built")

    context_handle = self._get_conditional_context()
    raw_handle = handle_return(
        driver.cuGraphConditionalHandleCreate(graph, context_handle, default_value, flags)
    )
    return GraphCondition._from_handle(<cydriver.CUgraphConditionalHandle><intptr_t>int(raw_handle))
def _cond_with_params(self, node_params) -> tuple:
    """Add a conditional node to the graph being captured and create a builder per branch.

    Parameters
    ----------
    node_params
        Node parameters already filled in by the caller; ``conditional.size``
        and ``conditional.phGraph_out`` are read here.

    Returns
    -------
    tuple
        One ``GraphBuilder`` per entry of ``node_params.conditional.phGraph_out``
        (``node_params.conditional.size`` of them), each owning a newly created stream.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    # Get current capture info to ensure we're in a valid state
    # The items between the graph and the trailing count are collected into
    # deps_info and forwarded positionally below.
    status, _, graph, *deps_info, num_dependencies = handle_return(
        driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
    )
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    # Add the conditional node to the graph
    # The result has as many entries as deps_info: a one-element list holding
    # the new node, followed by None for each remaining entry.
    deps_info_update = [
        [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))]
    ] + [None] * (len(deps_info) - 1)

    # Update the stream's capture dependencies
    # The new conditional node becomes the stream's only capture dependency.
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            self._mnff.stream.handle,
            *deps_info_update,  # dependencies, edgeData
            1,  # numDependencies
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

    # Create new graph builders for each condition
    return tuple(
        [
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[i],
                is_join_required=False,
            )
            for i in range(node_params.conditional.size)
        ]
    )
def if_then(self, condition: GraphCondition) -> GraphBuilder:
    """Adds an if condition branch and returns a new graph builder for it.

    The resulting if graph will only execute the branch if the
    condition evaluates to true at runtime.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        whether the branch executes.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created conditional graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3.
    TypeError
        If ``condition`` is not a ``GraphCondition``.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        ver_text = ".".join(map(str, driver_ver))
        raise RuntimeError(f"Driver version {ver_text} does not support conditional if")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        ver_text = ".".join(map(str, binding_ver))
        raise RuntimeError(f"Binding version {ver_text} does not support conditional if")
    if not isinstance(condition, GraphCondition):
        received = type(condition).__name__
        raise TypeError(
            f"condition must be a GraphCondition object (from GraphBuilder.create_condition()), got {received}"
        )

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = condition.handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    # A single body graph: the "then" branch.
    node_params.conditional.size = 1
    node_params.conditional.ctx = self._get_conditional_context()

    branches = self._cond_with_params(node_params)
    return branches[0]
def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]:
    """Adds an if-else condition branch and returns new graph builders for both branches.

    The resulting if graph will execute the branch if the condition
    evaluates to true at runtime, otherwise the else branch will execute.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        which branch executes.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        A tuple of two new graph builders, one for the if branch and one for the else branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.
    TypeError
        If ``condition`` is not a ``GraphCondition``.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        ver_text = ".".join(map(str, driver_ver))
        raise RuntimeError(f"Driver version {ver_text} does not support conditional if-else")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        ver_text = ".".join(map(str, binding_ver))
        raise RuntimeError(f"Binding version {ver_text} does not support conditional if-else")
    if not isinstance(condition, GraphCondition):
        received = type(condition).__name__
        raise TypeError(
            f"condition must be a GraphCondition object (from GraphBuilder.create_condition()), got {received}"
        )

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = condition.handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    # Two body graphs: the "if" branch and the "else" branch.
    node_params.conditional.size = 2
    node_params.conditional.ctx = self._get_conditional_context()

    branches = self._cond_with_params(node_params)
    return branches
def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]:
    """Adds a switch condition branch and returns new graph builders for all cases.

    The resulting switch graph will execute the branch whose case index
    matches the value of the condition at runtime. If no match is found, no
    branch will be executed.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` selecting
        which case executes.
    count : int
        The number of cases to add to the switch conditional.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of new graph builders, one for each branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.
    TypeError
        If ``condition`` is not a ``GraphCondition``.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        ver_text = ".".join(map(str, driver_ver))
        raise RuntimeError(f"Driver version {ver_text} does not support conditional switch")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        ver_text = ".".join(map(str, binding_ver))
        raise RuntimeError(f"Binding version {ver_text} does not support conditional switch")
    if not isinstance(condition, GraphCondition):
        received = type(condition).__name__
        raise TypeError(
            f"condition must be a GraphCondition object (from GraphBuilder.create_condition()), got {received}"
        )

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = condition.handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    # One body graph per requested case.
    node_params.conditional.size = count
    node_params.conditional.ctx = self._get_conditional_context()

    cases = self._cond_with_params(node_params)
    return cases
630 def while_loop(self, condition: GraphCondition) -> GraphBuilder:
631 """Adds a while loop and returns a new graph builder for it.
633 The resulting while loop graph will execute the branch repeatedly at runtime
634 until the condition evaluates to false.
636 The new builder inherits work dependencies from the original builder.
638 Parameters
639 ----------
640 condition : :class:`~graph.GraphCondition`
641 The condition variable from :meth:`create_condition` controlling
642 loop continuation.
644 Returns
645 -------
646 graph_builder : :obj:`~graph.GraphBuilder`
647 The newly created while loop graph builder.
649 """
650 if cy_driver_version() < (12, 3, 0): 1wxyz
651 raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional while loop")
652 if cy_binding_version() < (12, 3, 0): 1wxyz
653 raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional while loop")
654 if not isinstance(condition, GraphCondition): 1wxyz
655 raise TypeError(
656 f"condition must be a GraphCondition object (from "
657 f"GraphBuilder.create_condition()), got {type(condition).__name__}")
658 node_params = driver.CUgraphNodeParams() 1wxyz
659 node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL 1wxyz
660 node_params.conditional.handle = condition.handle 1wxyz
661 node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE 1wxyz
662 node_params.conditional.size = 1 1wxyz
663 node_params.conditional.ctx = self._get_conditional_context() 1wxyz
664 return self._cond_with_params(node_params)[0] 1wxyz
666 def close(self):
667 """Destroy the graph builder.
669 Closes the associated stream if we own it. Borrowed stream
670 object will instead have their references released.
672 """
673 self._mnff.close() 1ABPbcdefghijklmnopqrstua
675 def embed(self, child: GraphBuilder):
676 """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node.
678 Parameters
679 ----------
680 child : :obj:`~graph.GraphBuilder`
681 The child graph builder. Must have finished building.
682 """
683 if not child._building_ended: 1C
684 raise ValueError("Child graph has not finished building.")
686 if not self.is_building: 1C
687 raise ValueError("Parent graph is not being built.")
689 stream_handle = self._mnff.stream.handle 1C
690 _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return( 1C
691 driver.cuStreamGetCaptureInfo(stream_handle) 1C
692 )
694 # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
695 # for rationale
696 deps_info_trimmed = deps_info_out[:num_dependencies_out] 1C
697 deps_info_update = [ 1C
698 [ 1C
699 handle_return( 1C
700 driver.cuGraphAddChildGraphNode( 1C
701 graph_out, *deps_info_trimmed, num_dependencies_out, child._mnff.graph 1C
702 )
703 )
704 ]
705 ] + [None] * (len(deps_info_out) - 1) 1C
706 handle_return( 1C
707 driver.cuStreamUpdateCaptureDependencies( 1C
708 stream_handle, 1C
709 *deps_info_update, # dependencies, edgeData
710 1,
711 driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, 1C
712 )
713 )
715 def callback(self, fn, *, user_data=None):
716 """Add a host callback to the graph during stream capture.
718 The callback runs on the host CPU when the graph reaches this point
719 in execution. Two modes are supported:
721 - **Python callable**: Pass any callable. The GIL is acquired
722 automatically. The callable must take no arguments; use closures
723 or ``functools.partial`` to bind state.
724 - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance.
725 The function receives a single ``void*`` argument (the
726 ``user_data``). The caller must keep the ctypes wrapper alive
727 for the lifetime of the graph.
729 .. warning::
731 Callbacks must not call CUDA API functions. Doing so may
732 deadlock or corrupt driver state.
734 Parameters
735 ----------
736 fn : callable or ctypes function pointer
737 The callback function.
738 user_data : int or bytes-like, optional
739 Only for ctypes function pointers. If ``int``, passed as a raw
740 pointer (caller manages lifetime). If bytes-like, the data is
741 copied and its lifetime is tied to the graph.
742 """
743 cdef Stream stream = <Stream>self._mnff.stream 1GH
744 cdef cydriver.CUstream c_stream = as_cu(stream._h_stream) 1GH
745 cdef cydriver.CUstreamCaptureStatus capture_status
746 cdef cydriver.CUgraph c_graph = NULL 1GH
748 with nogil: 1GH
749 IF CUDA_CORE_BUILD_MAJOR >= 13:
750 HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( 1GH
751 c_stream, &capture_status, NULL, &c_graph, NULL, NULL, NULL))
752 ELSE:
753 HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
754 c_stream, &capture_status, NULL, &c_graph, NULL, NULL))
756 if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1GH
757 raise RuntimeError("Cannot add callback when graph is not being built")
759 cdef cydriver.CUhostFn c_fn
760 cdef void* c_user_data = NULL 1GH
761 _attach_host_callback_to_graph(c_graph, fn, user_data, &c_fn, &c_user_data) 1GH
763 with nogil: 1GH
764 HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data)) 1GH
class Graph:
    """An executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        # Holds the executable-graph handle apart from the Graph object, so the
        # finalizer registered below is bound to this holder, not the Graph.
        __slots__ = ("graph",)

        def __init__(self, graph_obj, graph):
            self.graph = graph
            # Destroy the handle when `graph_obj` is garbage collected.
            weakref.finalize(graph_obj, self.close)

        def close(self):
            # Safe to call more than once: the handle is cleared after the
            # first successful destroy.
            if self.graph:
                handle_return(driver.cuGraphExecDestroy(self.graph))
                self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        # Internal constructor: bypasses __init__ and wraps an existing
        # executable-graph handle.
        self = cls.__new__(cls)
        self._mnff = Graph._MembersNeededForFinalize(self, graph)
        return self

    def close(self):
        """Destroy the graph."""
        self._mnff.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        return self._mnff.graph

    def update(self, source: "GraphBuilder | GraphDefinition") -> None:
        """Update the graph using a new graph definition.

        The topology of the provided source must be identical to this graph.

        Parameters
        ----------
        source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
            The graph definition to update from. A GraphBuilder must have
            finished building.

        Raises
        ------
        ValueError
            If ``source`` is a GraphBuilder that has not finished building.
        TypeError
            If ``source`` is neither a GraphBuilder nor a GraphDefinition.
        CUDAError
            If the driver reports that the executable graph could not be
            updated; the message names the reported reason.

        """
        # Imported here rather than at module level, as in the original code.
        from cuda.core.graph import GraphDefinition

        cdef cydriver.CUgraph cu_graph
        cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)

        if isinstance(source, GraphBuilder):
            if not source._building_ended:
                raise ValueError("Graph has not finished building.")
            cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph)
        elif isinstance(source, GraphDefinition):
            cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
        else:
            raise TypeError(
                f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}")

        cdef cydriver.CUgraphExecUpdateResultInfo result_info
        cdef cydriver.CUresult err
        with nogil:
            err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
        if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
            reason = driver.CUgraphExecUpdateResult(result_info.result)
            # ``__doc__`` can be None (for example when docstrings are
            # stripped with ``python -OO``). Fall back to the enum name alone
            # so the real update failure is not masked by an AttributeError.
            detail = (reason.__doc__ or "").strip()
            if detail:
                msg = f"Graph update failed: {detail} ({reason.name})"
            else:
                msg = f"Graph update failed: {reason.name}"
            raise CUDAError(msg)
        # Any other status code is handed to HANDLE_RETURN.
        HANDLE_RETURN(err)

    def upload(self, stream: Stream):
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle))

    def launch(self, stream: Stream):
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph.

        """
        handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle))