Coverage for cuda / core / graph / _graph_builder.pyx: 90.98%
366 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 01:27 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 01:27 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5import weakref
6from dataclasses import dataclass
8from libc.stdint cimport intptr_t
10from cuda.bindings cimport cydriver
12from cuda.core.graph._utils cimport _attach_host_callback_to_graph
13from cuda.core._resource_handles cimport as_cu
14from cuda.core._stream cimport Stream
15from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
16from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
18from cuda.core._utils.cuda_utils import (
19 CUDAError,
20 driver,
21 handle_return,
22)
@dataclass
class GraphDebugPrintOptions:
    """Switches selecting what debug_dot_print() writes into the DOT file.

    Each attribute corresponds to one ``CUgraphDebugDot_flags`` bit; all of
    them are off unless explicitly enabled.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag is enabled (Default to False)
    runtime_types : bool
        Use CUDA Runtime structures for output (Default to False)
    kernel_node_params : bool
        Adds kernel parameter values to output (Default to False)
    memcpy_node_params : bool
        Adds memcpy parameter values to output (Default to False)
    memset_node_params : bool
        Adds memset parameter values to output (Default to False)
    host_node_params : bool
        Adds host parameter values to output (Default to False)
    event_node_params : bool
        Adds event parameter values to output (Default to False)
    ext_semas_signal_node_params : bool
        Adds external semaphore signal parameter values to output (Default to False)
    ext_semas_wait_node_params : bool
        Adds external semaphore wait parameter values to output (Default to False)
    kernel_node_attributes : bool
        Adds kernel node attributes to output (Default to False)
    handles : bool
        Adds node handles and every kernel function handle to output (Default to False)
    mem_alloc_node_params : bool
        Adds memory alloc parameter values to output (Default to False)
    mem_free_node_params : bool
        Adds memory free parameter values to output (Default to False)
    batch_mem_op_node_params : bool
        Adds batch mem op parameter values to output (Default to False)
    extra_topo_info : bool
        Adds edge numbering information (Default to False)
    conditional_node_params : bool
        Adds conditional node parameter values to output (Default to False)

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Fold the enabled options into a driver flag bitmask (internal use)."""
        # Every option name, upper-cased, is the suffix of its driver enum member.
        option_names = (
            "verbose",
            "runtime_types",
            "kernel_node_params",
            "memcpy_node_params",
            "memset_node_params",
            "host_node_params",
            "event_node_params",
            "ext_semas_signal_node_params",
            "ext_semas_wait_node_params",
            "kernel_node_attributes",
            "handles",
            "mem_alloc_node_params",
            "mem_free_node_params",
            "batch_mem_op_node_params",
            "extra_topo_info",
            "conditional_node_params",
        )
        bitmask = 0
        for option_name in option_names:
            if getattr(self, option_name):
                # The enum member is only resolved for options that are set,
                # so an unset option never touches the driver module.
                member_name = "CU_GRAPH_DEBUG_DOT_FLAGS_" + option_name.upper()
                bitmask |= getattr(driver.CUgraphDebugDot_flags, member_name)
        return bitmask
@dataclass
class GraphCompleteOptions:
    """Options for graph instantiation.

    These options are read by ``_instantiate_graph``, which turns every option
    that is set into the matching ``CUgraphInstantiate_flags`` bit. When
    ``upload_stream`` is set, its handle is additionally stored in the
    instantiate parameters as the upload stream.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Automatically free memory allocated in a graph before relaunching. (Default to False)
    upload_stream : Stream, optional
        Stream to use to automatically upload the graph after completion. (Default to None)
    device_launch : bool, optional
        Configure the graph to be launchable from the device. This flag can only
        be used on platforms which support unified addressing. This flag cannot be
        used in conjunction with auto_free_on_launch. (Default to False)
    use_node_priority : bool, optional
        Run the graph using the per-node priority attributes rather than the
        priority of the stream it is launched into. (Default to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate ``h_graph`` through the driver and wrap the result in a ``Graph``.

    Parameters
    ----------
    h_graph
        Graph handle handed to ``cuGraphInstantiateWithParams``.
    options : GraphCompleteOptions, optional
        When given (and truthy), each enabled option is translated into the
        corresponding instantiate flag; ``upload_stream`` also supplies the
        upload stream handle.

    Returns
    -------
    Graph
        The object produced by ``Graph._init`` from the instantiated handle.

    Raises
    ------
    RuntimeError
        If the ``result_out`` field of the instantiate parameters reports
        anything other than success after the driver call returned.
    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        instantiate_flags = 0
        if options.auto_free_on_launch:
            instantiate_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        upload_stream = options.upload_stream
        if upload_stream:
            instantiate_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = upload_stream.handle
        if options.device_launch:
            instantiate_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            instantiate_flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = instantiate_flags

    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))

    outcome = params.result_out
    result_codes = driver.CUgraphInstantiateResult
    # Known failure codes, checked in this order, paired with their messages.
    known_failures = [
        (
            result_codes.CUDA_GRAPH_INSTANTIATE_ERROR,
            "Instantiation failed for an unexpected reason which is described in the return value of the function.",
        ),
        (
            result_codes.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE,
            "Instantiation failed due to invalid structure, such as cycles.",
        ),
        (
            result_codes.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED,
            "Instantiation for device launch failed because the graph contained an unsupported operation.",
        ),
        (
            result_codes.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED,
            "Instantiation for device launch failed due to the nodes belonging to different contexts.",
        ),
    ]
    # This result code is only looked up when the bindings are 12.8 or newer.
    if cy_binding_version() >= (12, 8, 0):
        known_failures.append(
            (
                result_codes.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED,
                "One or more conditional handles are not associated with conditional builders.",
            )
        )

    for failure_code, message in known_failures:
        if outcome == failure_code:
            raise RuntimeError(message)
    if outcome != result_codes.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {outcome}")
    return graph
184class GraphBuilder:
185 """A graph under construction by stream capture.
187 A graph groups a set of CUDA kernels and other CUDA operations together and executes
188 them with a specified dependency tree. It speeds up the workflow by combining the
189 driver activities associated with CUDA kernel launches and CUDA API calls.
191 Directly creating a :obj:`~graph.GraphBuilder` is not supported due
192 to ambiguity. New graph builders should instead be created through a
193 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object.
195 """
class _MembersNeededForFinalize:
    """Resources of a graph builder that its weakref finalizer must release."""

    __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

    def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
        self.graph = None
        self.conditional_graph = conditional_graph
        self.stream = stream_obj
        self.is_stream_owner = is_stream_owner
        self.is_join_required = is_join_required
        # Run close() once the owning builder object is garbage collected.
        weakref.finalize(graph_builder_obj, self.close)

    def close(self):
        """Release the stream and graph held for the builder."""
        held_stream = self.stream
        if held_stream:
            if not self.is_join_required:
                status = handle_return(driver.cuStreamGetCaptureInfo(held_stream.handle))[0]
                if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                    # Note how this condition only occurs for the primary graph builder.
                    # This is because calling cuStreamEndCapture on streams that were split
                    # off of the primary would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
                    # Therefore, it is currently a requirement that users join all split graph
                    # builders before a graph builder can be cleanly destroyed.
                    handle_return(driver.cuStreamEndCapture(held_stream.handle))
            if self.is_stream_owner:
                held_stream.close()
        self.stream = None

        held_graph = self.graph
        if held_graph:
            handle_return(driver.cuGraphDestroy(held_graph))
        self.graph = None
        self.conditional_graph = None
227 __slots__ = ("__weakref__", "_building_ended", "_mnff")
def __init__(self):
    """Always raise: direct construction of a graph builder is not supported.

    Raises
    ------
    NotImplementedError
        Unconditionally; the message points callers to
        ``Device.create_graph_builder()`` or ``stream.create_graph_builder()``.
    """
    # NOTE(review): the message says "Graph object" although this class is
    # GraphBuilder -- confirm whether that wording is intentional.
    raise NotImplementedError(
        "directly creating a Graph object can be ambiguous. Please either "
        "call Device.create_graph_builder() or stream.create_graph_builder()"
    )
@classmethod
def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
    """Internal constructor that bypasses ``__init__`` (which always raises).

    Stores the stream, ownership flag, optional conditional graph and join
    requirement in a finalizer-backed holder, and marks building as not ended.
    """
    builder = cls.__new__(cls)
    builder._building_ended = False
    builder._mnff = GraphBuilder._MembersNeededForFinalize(
        builder,
        stream,
        is_stream_owner,
        conditional_graph,
        is_join_required,
    )
    return builder
@property
def stream(self) -> Stream:
    """Returns the stream associated with the graph builder.

    The value is read from the finalizer-backed member holder, which sets it
    to ``None`` once the builder has been closed.
    """
    return self._mnff.stream
@property
def is_join_required(self) -> bool:
    """Returns True if this graph builder must be joined before building is ended.

    Builders created by ``split()`` are constructed with this flag set; the
    flag is fixed when the builder is created.
    """
    return self._mnff.is_join_required
def begin_building(self, mode="relaxed") -> GraphBuilder:
    """Begins the building process.

    Build `mode` for controlling interaction with other API calls must be one of the following:

    - `global` : Prohibit potentially unsafe operations across all streams in the process.
    - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
    - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

    Parameters
    ----------
    mode : str, optional
        Build mode to control the interaction with other API calls that are potentially unsafe.
        Default set to use relaxed.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        This builder, so that calls can be chained.

    Raises
    ------
    RuntimeError
        If building has already been ended on this builder.
    ValueError
        If `mode` is not one of the supported build modes.

    """
    if self._building_ended:
        raise RuntimeError("Cannot resume building after building has ended.")

    # Single point of validation: any value that is not a known mode is
    # rejected here, before the driver is called.
    if mode == "global":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL
    elif mode == "thread_local":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
    elif mode == "relaxed":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED
    else:
        raise ValueError(f"Unsupported build mode: {mode}")

    members = self._mnff
    if members.conditional_graph:
        # Conditional builders capture into the graph they were created for.
        handle_return(
            driver.cuStreamBeginCaptureToGraph(
                members.stream.handle,
                members.conditional_graph,
                None,  # dependencies
                None,  # dependencyData
                0,  # numDependencies
                capture_mode,
            )
        )
    else:
        handle_return(driver.cuStreamBeginCapture(members.stream.handle, capture_mode))
    return self
@property
def is_building(self) -> bool:
    """Whether the builder's stream currently has an active capture.

    Raises
    ------
    RuntimeError
        If the capture has been invalidated.
    NotImplementedError
        If the driver reports a capture status this code does not know.
    """
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
    status = capture_info[0]
    if status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
        return False
    if status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return True
    if status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. Build process must now be ended."
        )
    raise NotImplementedError(f"Unsupported capture status type received: {status}")
def end_building(self) -> GraphBuilder:
    """Stop capturing and keep the resulting graph handle on the builder.

    Raises
    ------
    RuntimeError
        If the builder is not currently building.
    """
    if not self.is_building:
        raise RuntimeError("Graph builder is not building.")

    members = self._mnff
    captured = handle_return(driver.cuStreamEndCapture(self.stream.handle))
    # A conditional builder stores the result as its conditional graph;
    # every other builder stores it as its own graph.
    if members.conditional_graph:
        members.conditional_graph = captured
    else:
        members.graph = captured

    # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
    # resume the build process after the first call to end_building()
    self._building_ended = True
    return self
def complete(self, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate the captured graph and return it as a :obj:`~graph.Graph`.

    Parameters
    ----------
    options : :obj:`~graph.GraphCompleteOptions`, optional
        Customizable dataclass for the graph builder completion options.

    Returns
    -------
    graph : :obj:`~graph.Graph`
        The newly built graph.

    Raises
    ------
    RuntimeError
        If ``end_building()`` has not been called on this builder yet.

    """
    if self._building_ended:
        return _instantiate_graph(self._mnff.graph, options)
    raise RuntimeError("Graph has not finished building.")
def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
    """Generates a DOT debug file for the graph builder.

    Parameters
    ----------
    path : str, bytes or os.PathLike
        File path to use for writing debug DOT output. The value is converted
        to bytes with :func:`os.fsencode` before being handed to the driver,
        so ``bytes`` paths are passed through unchanged.
    options : :obj:`~graph.GraphDebugPrintOptions`, optional
        Customizable dataclass for the debug print options.

    Raises
    ------
    RuntimeError
        If ``end_building()`` has not been called on this builder yet.

    """
    import os

    if not self._building_ended:
        raise RuntimeError("Graph has not finished building.")
    flags = options._to_flags() if options else 0
    # The driver binding takes a C string; normalise str / PathLike input so the
    # documented ``str`` path works as well as an already-encoded bytes path.
    encoded_path = os.fsencode(path)
    handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, encoded_path, flags))
def split(self, count: int) -> tuple[GraphBuilder, ...]:
    """Splits the original graph builder into multiple graph builders.

    The new builders inherit work dependencies from the original builder.
    The original builder is reused for the split and is returned first in the tuple.

    Parameters
    ----------
    count : int
        The number of graph builders to split the graph builder into.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of split graph builders. The first graph builder in the tuple
        is always the original graph builder.

    Raises
    ------
    ValueError
        If `count` is smaller than 2.

    """
    if count < 2:
        raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

    source_stream = self._mnff.stream
    event = source_stream.record()
    result = [self]
    try:
        for _ in range(count - 1):
            stream = source_stream.device.create_stream()
            # Each new stream waits on the event recorded on the original stream.
            stream.wait(event)
            result.append(
                GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True)
            )
    finally:
        # Close the event even when creating or wiring up a stream fails.
        event.close()
    return tuple(result)
395 @staticmethod
396 def join(*graph_builders) -> GraphBuilder:
397 """Joins multiple graph builders into a single graph builder.
399 The returned builder inherits work dependencies from the provided builders.
401 Parameters
402 ----------
403 *graph_builders : :obj:`~graph.GraphBuilder`
404 The graph builders to join.
406 Returns
407 -------
408 graph_builder : :obj:`~graph.GraphBuilder`
409 The newly joined graph builder.
411 """
412 if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): 1ABbcdefghijklmnopqrstua
413 raise TypeError("All arguments must be GraphBuilder instances")
414 if len(graph_builders) < 2: 1ABbcdefghijklmnopqrstua
415 raise ValueError("Must join with at least two graph builders") 1A
417 # Discover the root builder others should join
418 root_idx = 0 1ABbcdefghijklmnopqrstua
419 for i, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua
420 if not builder.is_join_required: 1ABbcdefghijklmnopqrstua
421 root_idx = i 1ABbcdefghijklmnopqrstua
422 break 1ABbcdefghijklmnopqrstua
424 # Join all onto the root builder
425 root_bdr = graph_builders[root_idx] 1ABbcdefghijklmnopqrstua
426 for idx, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua
427 if idx == root_idx: 1ABbcdefghijklmnopqrstua
428 continue 1ABbcdefghijklmnopqrstua
429 root_bdr.stream.wait(builder.stream) 1ABbcdefghijklmnopqrstua
430 builder.close() 1ABbcdefghijklmnopqrstua
432 return root_bdr 1ABbcdefghijklmnopqrstua
def __cuda_stream__(self) -> tuple[int, int]:
    """Return an instance of a __cuda_stream__ protocol.

    The call is forwarded unchanged to the builder's underlying stream.
    """
    return self.stream.__cuda_stream__()
def _get_conditional_context(self) -> driver.CUcontext:
    """Return the context handle of the builder's stream.

    Used when creating conditional handles and when filling in the ``ctx``
    field of conditional node parameters.
    """
    return self._mnff.stream.context.handle
def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle:
    """Create a conditional handle bound to the graph currently being captured.

    Parameters
    ----------
    default_value : int, optional
        The default value to assign to the conditional handle. When omitted,
        the handle is created with value 0 and without the assign-default flag.

    Returns
    -------
    handle : driver.CUgraphConditionalHandle
        The newly created conditional handle.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3, or if the
        builder's stream is not actively capturing.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional handles")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional handles")

    if default_value is None:
        default_value = 0
        flags = 0
    else:
        flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT

    capture_status, _, capturing_graph, *_, _ = handle_return(
        driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
    )
    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot create a conditional handle when graph is not being built")

    creation_result = driver.cuGraphConditionalHandleCreate(
        capturing_graph, self._get_conditional_context(), default_value, flags
    )
    return handle_return(creation_result)
def _cond_with_params(self, node_params) -> tuple:
    """Add the conditional node described by ``node_params`` and return its branch builders.

    The node is added to the graph currently being captured on this builder's
    stream, the stream's capture dependencies are replaced by that single new
    node, and one new builder is created for each of the
    ``node_params.conditional.size`` entries of
    ``node_params.conditional.phGraph_out``.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    # Get current capture info to ensure we're in a valid state
    # The tuple is unpacked positionally: status, one ignored element, the graph,
    # then the remaining middle elements (``deps_info``) and, last, the dependency count.
    status, _, graph, *deps_info, num_dependencies = handle_return(
        driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
    )
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    # Add the conditional node to the graph
    # The result has one slot per ``deps_info`` element: a one-item list holding
    # the new node first, then ``None`` for every remaining slot.
    deps_info_update = [
        [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))]
    ] + [None] * (len(deps_info) - 1)

    # Update the stream's capture dependencies
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            self._mnff.stream.handle,
            *deps_info_update,  # dependencies, edgeData
            1,  # numDependencies
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

    # Create new graph builders for each condition
    # Each builder owns a freshly created stream on the same device and is not
    # marked as requiring a join.
    return tuple(
        [
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[i],
                is_join_required=False,
            )
            for i in range(node_params.conditional.size)
        ]
    )
def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
    """Add an ``if`` conditional node and return the builder for its body.

    The resulting if graph will only execute the branch if the conditional
    handle evaluates to true at runtime.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the if conditional.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created conditional graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if")

    if_params = driver.CUgraphNodeParams()
    if_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    if_params.conditional.handle = handle
    if_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    if_params.conditional.size = 1
    if_params.conditional.ctx = self._get_conditional_context()

    # A single-branch conditional yields exactly one builder.
    branch_builders = self._cond_with_params(if_params)
    return branch_builders[0]
def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]:
    """Add an ``if``/``else`` conditional node and return the builders for both bodies.

    The resulting if graph will execute the branch if the conditional handle
    evaluates to true at runtime, otherwise the else branch will execute.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the if-else conditional.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        A tuple of two new graph builders, one for the if branch and one for the else branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if-else")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if-else")

    branch_params = driver.CUgraphNodeParams()
    branch_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    branch_params.conditional.handle = handle
    # Same conditional type as a plain ``if``; the size of 2 adds the else body.
    branch_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    branch_params.conditional.size = 2
    branch_params.conditional.ctx = self._get_conditional_context()

    branch_builders = self._cond_with_params(branch_params)
    return branch_builders
def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]:
    """Add a ``switch`` conditional node and return one builder per case.

    The resulting switch graph will execute the branch that matches the
    case index of the conditional handle at runtime. If no match is found, no branch
    will be executed.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the switch conditional.
    count : int
        The number of cases to add to the switch conditional.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of new graph builders, one for each branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional switch")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional switch")

    switch_params = driver.CUgraphNodeParams()
    switch_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    switch_params.conditional.handle = handle
    switch_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    switch_params.conditional.size = count
    switch_params.conditional.ctx = self._get_conditional_context()

    case_builders = self._cond_with_params(switch_params)
    return case_builders
def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
    """Add a ``while`` conditional node and return the builder for the loop body.

    The resulting while loop graph will execute the branch repeatedly at runtime
    until the conditional handle evaluates to false.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the while loop.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created while loop graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional while loop")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional while loop")

    loop_params = driver.CUgraphNodeParams()
    loop_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    loop_params.conditional.handle = handle
    loop_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    loop_params.conditional.size = 1
    loop_params.conditional.ctx = self._get_conditional_context()

    # A loop has a single body, so exactly one builder comes back.
    body_builders = self._cond_with_params(loop_params)
    return body_builders[0]
def close(self):
    """Destroy the graph builder.

    The associated stream is closed when this builder owns it; a borrowed
    stream object only has its reference released.

    """
    # All teardown is delegated to the finalizer-state holder.
    finalize_state = self._mnff
    finalize_state.close()
def add_child(self, child_graph: GraphBuilder):
    """Embed a finished :obj:`~graph.GraphBuilder` into this builder as a child node.

    The child node is added to the graph currently being captured, and the
    capture dependencies of this builder's stream are then set to that
    single new node.

    Parameters
    ----------
    child_graph : :obj:`~graph.GraphBuilder`
        The child graph builder. Must have finished building.

    Raises
    ------
    ValueError
        If ``child_graph`` has not finished building, or if this builder
        is not currently being built.
    """
    if not child_graph._building_ended:
        raise ValueError("Child graph has not finished building.")
    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    capture_stream = self._mnff.stream.handle
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(capture_stream))
    # Everything between the graph and the trailing count is dependency info.
    _, _, graph_out, *deps_info_out, num_dependencies_out = capture_info

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    deps_info_trimmed = deps_info_out[:num_dependencies_out]
    child_node = handle_return(
        driver.cuGraphAddChildGraphNode(
            graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph
        )
    )

    # The new dependency set is just the child node; any remaining
    # dependency-info slots are filled with None.
    deps_info_update = [[child_node]]
    deps_info_update.extend([None] * (len(deps_info_out) - 1))

    set_deps_flag = driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            capture_stream,
            *deps_info_update,  # dependencies, edgeData
            1,
            set_deps_flag,
        )
    )
def callback(self, fn, *, user_data=None):
    """Record a host callback into the graph while the stream is capturing.

    The callback is executed on the host CPU once graph execution reaches
    this point. Two kinds of callback are accepted:

    - **Python callable**: any callable taking no arguments. The GIL is
      acquired automatically; bind state with closures or
      ``functools.partial``.
    - **ctypes function pointer**: a ``ctypes.CFUNCTYPE`` instance that
      receives a single ``void*`` argument (the ``user_data``). The
      caller must keep the ctypes wrapper alive for the lifetime of the
      graph.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The callback function.
    user_data : int or bytes-like, optional
        Only for ctypes function pointers. If ``int``, passed as a raw
        pointer (caller manages lifetime). If bytes-like, the data is
        copied and its lifetime is tied to the graph.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    cdef Stream capture_stream = <Stream>self._mnff.stream
    cdef cydriver.CUstream c_stream = as_cu(capture_stream._h_stream)
    cdef cydriver.CUstreamCaptureStatus status
    cdef cydriver.CUgraph c_graph = NULL
    cdef cydriver.CUhostFn host_fn
    cdef void* host_arg = NULL

    # Query the capture state and the graph being captured; the driver
    # call takes one more argument when built against CUDA 13 or newer.
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                c_stream, &status, NULL, &c_graph, NULL, NULL, NULL))
        ELSE:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                c_stream, &status, NULL, &c_graph, NULL, NULL))

    if status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    # Resolve `fn`/`user_data` into the C function pointer and argument
    # that are handed to the driver below.
    _attach_host_callback_to_graph(c_graph, fn, user_data, &host_fn, &host_arg)

    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, host_fn, host_arg))
class Graph:
    """An executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        """Holds the ``CUgraphExec`` handle so it can be destroyed on finalization."""

        __slots__ = ("graph",)

        def __init__(self, graph_obj, graph):
            self.graph = graph
            # Destroy the executable graph when the owning Graph object is
            # garbage collected (or at interpreter exit).
            weakref.finalize(graph_obj, self.close)

        def close(self):
            """Destroy the executable graph once; later calls are no-ops."""
            if self.graph:
                handle_return(driver.cuGraphExecDestroy(self.graph))
                self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph):
        """Internal constructor: wrap an existing executable graph handle."""
        self = cls.__new__(cls)
        self._mnff = Graph._MembersNeededForFinalize(self, graph)
        return self

    def close(self):
        """Destroy the graph."""
        self._mnff.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        return self._mnff.graph

    def update(self, source: "GraphBuilder | GraphDefinition") -> None:
        """Update the graph using a new graph definition.

        The topology of the provided source must be identical to this graph.

        Parameters
        ----------
        source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
            The graph definition to update from. A GraphBuilder must have
            finished building.

        Raises
        ------
        ValueError
            If ``source`` is a GraphBuilder that has not finished building.
        TypeError
            If ``source`` is neither a GraphBuilder nor a GraphDefinition.
        CUDAError
            If the driver reports that the executable graph update failed.

        """
        from cuda.core.graph import GraphDefinition

        cdef cydriver.CUgraph cu_graph
        cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)

        if isinstance(source, GraphBuilder):
            if not source._building_ended:
                raise ValueError("Graph has not finished building.")
            cu_graph = <cydriver.CUgraph><intptr_t>int(source._mnff.graph)
        elif isinstance(source, GraphDefinition):
            cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
        else:
            raise TypeError(
                f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}")

        cdef cydriver.CUgraphExecUpdateResultInfo result_info
        cdef cydriver.CUresult err
        with nogil:
            err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
        if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
            reason = driver.CUgraphExecUpdateResult(result_info.result)
            # The enum docstring may be missing (e.g. when docstrings are
            # stripped); never let that mask the real update failure.
            detail = (reason.__doc__ or "").strip()
            if detail:
                msg = f"Graph update failed: {detail} ({reason.name})"
            else:
                msg = f"Graph update failed: {reason.name}"
            raise CUDAError(msg)
        HANDLE_RETURN(err)

    def upload(self, stream: Stream):
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle))

    def launch(self, stream: Stream):
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph

        """
        handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle))