Coverage for cuda/core/graph/_graph_builder.pyx: 88.17%
389 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5import weakref
6from dataclasses import dataclass
7from typing import TYPE_CHECKING
9from libc.stdint cimport intptr_t
11from cuda.bindings cimport cydriver
13from cuda.core.graph._graph_definition cimport GraphCondition
14from cuda.core.graph._utils cimport _attach_host_callback_to_graph
15from cuda.core._resource_handles cimport as_cu
16from cuda.core._stream cimport Stream
17from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
18from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
20from cuda.core._utils.cuda_utils import (
21 CUDAError,
22 driver,
23 handle_return,
24)
26if TYPE_CHECKING:
27 from cuda.core.graph._graph_definition import GraphDefinition
29__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions']
@dataclass
class GraphDebugPrintOptions:
    """Selects which details ``debug_dot_print()`` includes in its output.

    Every option defaults to ``False``. Each enabled option contributes the
    matching ``CU_GRAPH_DEBUG_DOT_FLAGS_*`` driver flag.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag is enabled (Default to False)
    runtime_types : bool
        Use CUDA Runtime structures for output (Default to False)
    kernel_node_params : bool
        Adds kernel parameter values to output (Default to False)
    memcpy_node_params : bool
        Adds memcpy parameter values to output (Default to False)
    memset_node_params : bool
        Adds memset parameter values to output (Default to False)
    host_node_params : bool
        Adds host parameter values to output (Default to False)
    event_node_params : bool
        Adds event parameter values to output (Default to False)
    ext_semas_signal_node_params : bool
        Adds external semaphore signal parameter values to output (Default to False)
    ext_semas_wait_node_params : bool
        Adds external semaphore wait parameter values to output (Default to False)
    kernel_node_attributes : bool
        Adds kernel node attributes to output (Default to False)
    handles : bool
        Adds node handles and every kernel function handle to output (Default to False)
    mem_alloc_node_params : bool
        Adds memory alloc parameter values to output (Default to False)
    mem_free_node_params : bool
        Adds memory free parameter values to output (Default to False)
    batch_mem_op_node_params : bool
        Adds batch mem op parameter values to output (Default to False)
    extra_topo_info : bool
        Adds edge numbering information (Default to False)
    conditional_node_params : bool
        Adds conditional node parameter values to output (Default to False)

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Convert options to CUDA driver API flags (internal use)."""
        # Each option is named after its driver flag:
        # ``<option>`` <-> ``CU_GRAPH_DEBUG_DOT_FLAGS_<OPTION>``.
        option_names = (
            "verbose",
            "runtime_types",
            "kernel_node_params",
            "memcpy_node_params",
            "memset_node_params",
            "host_node_params",
            "event_node_params",
            "ext_semas_signal_node_params",
            "ext_semas_wait_node_params",
            "kernel_node_attributes",
            "handles",
            "mem_alloc_node_params",
            "mem_free_node_params",
            "batch_mem_op_node_params",
            "extra_topo_info",
            "conditional_node_params",
        )
        flags = 0
        for option in option_names:
            # The driver enum is only consulted for options that are enabled.
            if getattr(self, option):
                flags |= getattr(driver.CUgraphDebugDot_flags, f"CU_GRAPH_DEBUG_DOT_FLAGS_{option.upper()}")
        return flags
@dataclass
class GraphCompleteOptions:
    """Settings applied when a graph is instantiated.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Automatically free memory allocated in a graph before relaunching.
        (Default to False)
    upload_stream : Stream, optional
        Stream to use to automatically upload the graph after completion.
        (Default to None)
    device_launch : bool, optional
        Configure the graph to be launchable from the device. This flag can
        only be used on platforms which support unified addressing, and it
        cannot be used in conjunction with auto_free_on_launch.
        (Default to False)
    use_node_priority : bool, optional
        Run the graph using the per-node priority attributes rather than the
        priority of the stream it is launched into. (Default to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> "Graph":
    """Instantiate ``h_graph`` through the driver and wrap it in a ``Graph``.

    Parameters
    ----------
    h_graph
        Graph handle passed to ``cuGraphInstantiateWithParams``.
    options : GraphCompleteOptions, optional
        When given, each enabled option is translated into the matching
        ``CUDA_GRAPH_INSTANTIATE_FLAG_*`` bit; ``upload_stream`` also sets
        ``hUploadStream`` on the instantiate params.

    Returns
    -------
    Graph
        The object produced by ``Graph._init`` from the instantiated handle.

    Raises
    ------
    RuntimeError
        If the params' ``result_out`` field reports anything other than success.
    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        flag_enum = driver.CUgraphInstantiate_flags
        flags = 0
        if options.auto_free_on_launch:
            flags |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            flags |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            flags |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            flags |= flag_enum.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = flags

    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))

    outcome = params.result_out
    result_enum = driver.CUgraphInstantiateResult
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        return graph

    # Translate the detailed instantiate result into a readable failure reason.
    if outcome == result_enum.CUDA_GRAPH_INSTANTIATE_ERROR:
        reason = "Instantiation failed for an unexpected reason which is described in the return value of the function."
    elif outcome == result_enum.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        reason = "Instantiation failed due to invalid structure, such as cycles."
    elif outcome == result_enum.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        reason = "Instantiation for device launch failed because the graph contained an unsupported operation."
    elif outcome == result_enum.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        reason = "Instantiation for device launch failed due to the nodes belonging to different contexts."
    elif (
        # This result code is only looked up on bindings >= 12.8.
        cy_binding_version() >= (12, 8, 0)
        and outcome == result_enum.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        reason = "One or more conditional handles are not associated with conditional builders."
    else:
        reason = f"Graph instantiation failed with unexpected error code: {outcome}"
    raise RuntimeError(reason)
192class GraphBuilder:
193 """A graph under construction by stream capture.
195 A graph groups a set of CUDA kernels and other CUDA operations together and executes
196 them with a specified dependency tree. It speeds up the workflow by combining the
197 driver activities associated with CUDA kernel launches and CUDA API calls.
199 Directly creating a :obj:`~graph.GraphBuilder` is not supported due
200 to ambiguity. New graph builders should instead be created through a
201 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object.
203 """
class _MembersNeededForFinalize:
    """State the builder's finalizer needs, kept separate from the builder.

    ``close`` is registered with ``weakref.finalize`` on the owning graph
    builder, and can also be called directly.
    """

    __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

    def __init__(self, graph_builder_obj: GraphBuilder, stream_obj: Stream | None, is_stream_owner: bool, conditional_graph, is_join_required: bool) -> None:
        self.graph = None
        self.stream = stream_obj
        self.is_stream_owner = is_stream_owner
        self.is_join_required = is_join_required
        self.conditional_graph = conditional_graph
        # Registered last so every attribute read by close() is already set.
        weakref.finalize(graph_builder_obj, self.close)

    def close(self) -> None:
        """End any capture still in progress, then release the stream and graph."""
        stream = self.stream
        if stream:
            if not self.is_join_required:
                capture_status = handle_return(driver.cuStreamGetCaptureInfo(stream.handle))[0]
                if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
                    # Only the primary graph builder reaches this point.
                    # Calling cuStreamEndCapture on a stream that was split off
                    # of the primary would error out with
                    # CUDA_ERROR_STREAM_CAPTURE_UNJOINED, so users are currently
                    # required to join all split graph builders before a graph
                    # builder can be cleanly destroyed.
                    handle_return(driver.cuStreamEndCapture(stream.handle))
            if self.is_stream_owner:
                stream.close()
        self.stream = None

        graph = self.graph
        if graph:
            handle_return(driver.cuGraphDestroy(graph))
        self.graph = None
        self.conditional_graph = None
235 __slots__ = ("__weakref__", "_building_ended", "_mnff")
def __init__(self) -> None:
    """Reject direct construction.

    Raises
    ------
    NotImplementedError
        Always. Builders are meant to be created with
        ``Device.create_graph_builder()`` or ``stream.create_graph_builder()``.
    """
    # The message names GraphBuilder (this class) rather than Graph.
    raise NotImplementedError(
        "directly creating a GraphBuilder object can be ambiguous. Please either "
        "call Device.create_graph_builder() or stream.create_graph_builder()"
    )
@classmethod
def _init(cls, stream: Stream | None, is_stream_owner: bool, conditional_graph: object = None, is_join_required: bool = False) -> GraphBuilder:
    """Internal constructor that bypasses the public ``__init__``.

    Creates the instance with ``cls.__new__`` and attaches the finalizer
    state built from the given stream, ownership flag, conditional graph
    and join requirement.
    """
    builder = cls.__new__(cls)
    builder._building_ended = False
    builder._mnff = GraphBuilder._MembersNeededForFinalize(
        builder,
        stream,
        is_stream_owner,
        conditional_graph,
        is_join_required,
    )
    return builder
@property
def stream(self) -> Stream:
    """The stream this graph builder is associated with."""
    members = self._mnff
    return members.stream
@property
def is_join_required(self) -> bool:
    """Whether this graph builder must be joined before building is ended."""
    members = self._mnff
    return members.is_join_required
def begin_building(self, mode: str | None = "relaxed") -> GraphBuilder:
    """Begins the building process.

    Build `mode` for controlling interaction with other API calls must be one of the following:

    - `global` : Prohibit potentially unsafe operations across all streams in the process.
    - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
    - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

    Parameters
    ----------
    mode : str, optional
        Build mode to control the interaction with other API calls that are potentially unsafe.
        Default set to use relaxed.

    Returns
    -------
    graph_builder : GraphBuilder
        This builder, so calls can be chained.

    Raises
    ------
    RuntimeError
        If building has already ended on this builder.
    ValueError
        If `mode` is not one of the supported build modes.

    """
    if self._building_ended:
        raise RuntimeError("Cannot resume building after building has ended.")

    # One chain both validates the mode and selects the capture mode.
    # Unknown values are rejected before the stream is touched.
    if mode == "global":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL
    elif mode == "thread_local":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
    elif mode == "relaxed":
        capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED
    else:
        raise ValueError(f"Unsupported build mode: {mode}")

    if self._mnff.conditional_graph:
        # A conditional builder captures into its existing conditional graph.
        handle_return(
            driver.cuStreamBeginCaptureToGraph(
                self._mnff.stream.handle,
                self._mnff.conditional_graph,
                None,  # dependencies
                None,  # dependencyData
                0,  # numDependencies
                capture_mode,
            )
        )
    else:
        handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode))
    return self
@property
def is_building(self) -> bool:
    """Whether the builder's stream is currently capturing.

    Raises
    ------
    RuntimeError
        If the capture has been invalidated.
    NotImplementedError
        If the driver reports a capture status this code does not handle.
    """
    status_enum = driver.CUstreamCaptureStatus
    capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
    if capture_status == status_enum.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return True
    if capture_status == status_enum.CU_STREAM_CAPTURE_STATUS_NONE:
        return False
    if capture_status == status_enum.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. Build process must now be ended."
        )
    raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")
def end_building(self) -> GraphBuilder:
    """Ends the building process and stores the captured graph.

    Returns
    -------
    graph_builder : GraphBuilder
        This builder, so calls can be chained.

    Raises
    ------
    RuntimeError
        If the builder is not currently building.
    """
    if not self.is_building:
        raise RuntimeError("Graph builder is not building.")

    members = self._mnff
    captured = handle_return(driver.cuStreamEndCapture(self.stream.handle))
    # A conditional builder keeps its result in the conditional-graph slot.
    if members.conditional_graph:
        members.conditional_graph = captured
    else:
        members.graph = captured

    # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
    # resume the build process after the first call to end_building()
    self._building_ended = True
    return self
def complete(self, options: GraphCompleteOptions | None = None) -> "Graph":
    """Completes the graph builder and returns the built :obj:`~graph.Graph` object.

    Parameters
    ----------
    options : :obj:`~graph.GraphCompleteOptions`, optional
        Customizable dataclass for the graph builder completion options.

    Returns
    -------
    graph : :obj:`~graph.Graph`
        The newly built graph.

    Raises
    ------
    RuntimeError
        If building has not been ended yet.

    """
    if self._building_ended:
        return _instantiate_graph(self._mnff.graph, options)
    raise RuntimeError("Graph has not finished building.")
def debug_dot_print(self, path: str, options: GraphDebugPrintOptions | None = None) -> None:
    """Generates a DOT debug file for the graph builder.

    Parameters
    ----------
    path : str
        File path to use for writing debug DOT output
    options : :obj:`~graph.GraphDebugPrintOptions`, optional
        Customizable dataclass for the debug print options.

    Raises
    ------
    RuntimeError
        If building has not been ended yet.

    """
    if not self._building_ended:
        raise RuntimeError("Graph has not finished building.")
    if options:
        flags = options._to_flags()
    else:
        flags = 0
    # path_bytes stays bound for the whole call, so c_path remains valid.
    cdef bytes path_bytes = path.encode('utf-8')
    cdef const char* c_path = path_bytes
    handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, c_path, flags))
def split(self, count: int) -> tuple[GraphBuilder, ...]:
    """Splits the original graph builder into multiple graph builders.

    The new builders inherit work dependencies from the original builder.
    The original builder is reused for the split and is returned first in the tuple.

    Parameters
    ----------
    count : int
        The number of graph builders to split the graph builder into.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of split graph builders. The first graph builder in the tuple
        is always the original graph builder.

    Raises
    ------
    ValueError
        If `count` is less than 2.

    """
    if count < 2:
        raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

    origin_stream = self._mnff.stream
    event = origin_stream.record()
    result = [self]
    try:
        for _ in range(count - 1):
            # Each new stream waits on the event recorded on the original stream.
            stream = origin_stream.device.create_stream()
            stream.wait(event)
            result.append(
                GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True)
            )
    finally:
        # Close the event even if creating or wiring up a stream fails.
        event.close()
    return tuple(result)
405 @staticmethod
406 def join(*graph_builders: GraphBuilder) -> GraphBuilder:
407 """Joins multiple graph builders into a single graph builder.
409 The returned builder inherits work dependencies from the provided builders.
411 Parameters
412 ----------
413 *graph_builders : :obj:`~graph.GraphBuilder`
414 The graph builders to join.
416 Returns
417 -------
418 graph_builder : :obj:`~graph.GraphBuilder`
419 The newly joined graph builder.
421 """
422 if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): 1ABbcdefghijklmnopqrstua
423 raise TypeError("All arguments must be GraphBuilder instances")
424 if len(graph_builders) < 2: 1ABbcdefghijklmnopqrstua
425 raise ValueError("Must join with at least two graph builders") 1A
427 # Discover the root builder others should join
428 root_idx = 0 1ABbcdefghijklmnopqrstua
429 for i, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua
430 if not builder.is_join_required: 1ABbcdefghijklmnopqrstua
431 root_idx = i 1ABbcdefghijklmnopqrstua
432 break 1ABbcdefghijklmnopqrstua
434 # Join all onto the root builder
435 root_bdr = graph_builders[root_idx] 1ABbcdefghijklmnopqrstua
436 for idx, builder in enumerate(graph_builders): 1ABbcdefghijklmnopqrstua
437 if idx == root_idx: 1ABbcdefghijklmnopqrstua
438 continue 1ABbcdefghijklmnopqrstua
439 root_bdr.stream.wait(builder.stream) 1ABbcdefghijklmnopqrstua
440 builder.close() 1ABbcdefghijklmnopqrstua
442 return root_bdr 1ABbcdefghijklmnopqrstua
def __cuda_stream__(self) -> tuple[int, int]:
    """Implement the ``__cuda_stream__`` protocol by delegating to the builder's stream."""
    underlying = self.stream
    return underlying.__cuda_stream__()
def _get_conditional_context(self) -> driver.CUcontext:
    """Return the handle of the context attached to the builder's stream."""
    context = self._mnff.stream.context
    return context.handle
def create_condition(self, default_value: int | None = None) -> GraphCondition:
    """Create a condition variable for use with conditional nodes.

    The returned :class:`GraphCondition` object is passed to conditional-node
    builder methods (:meth:`if_then`, :meth:`if_else`, :meth:`while_loop`,
    :meth:`switch`). Its value is controlled at runtime by device code via
    ``cudaGraphSetConditional``.

    Parameters
    ----------
    default_value : int, optional
        The default value to assign to the condition. If None, no
        default is assigned.

    Returns
    -------
    GraphCondition
        A condition variable for controlling conditional execution.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3, or if the
        builder's stream is not actively capturing.
    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional handles")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional handles")

    if default_value is None:
        # No default requested: pass 0 with no assign-default flag.
        default_value = 0
        flags = 0
    else:
        flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT

    # Only the capture status and the graph being captured are needed here.
    status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot create a condition when graph is not being built")

    raw_handle = handle_return(
        driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags)
    )
    return GraphCondition._from_handle(<cydriver.CUgraphConditionalHandle><intptr_t>int(raw_handle))
def _cond_with_params(self, node_params: object) -> tuple[GraphBuilder, ...]:
    """Add a conditional node to the graph being captured and return its branch builders.

    The node is added with the stream's current capture dependencies, the
    stream's capture dependencies are then replaced by that node, and one
    new graph builder (each owning a newly created stream) is returned per
    entry of ``node_params.conditional.phGraph_out``.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    capture_stream = self._mnff.stream

    # Get current capture info to ensure we're in a valid state
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(capture_stream.handle))
    status, _, graph, *deps_info, num_dependencies = capture_info
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    # Add the conditional node to the graph
    cond_node = handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))

    # The new node becomes the single dependency; every remaining
    # dependency-info slot is passed as None.
    deps_info_update = [[cond_node]]
    deps_info_update.extend([None] * (len(deps_info) - 1))

    # Update the stream's capture dependencies
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            capture_stream.handle,
            *deps_info_update,  # dependencies, edgeData
            1,  # numDependencies
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

    # Create new graph builders for each condition
    branch_builders = []
    for branch_idx in range(node_params.conditional.size):
        branch_builders.append(
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[branch_idx],
                is_join_required=False,
            )
        )
    return tuple(branch_builders)
def if_then(self, condition: GraphCondition) -> GraphBuilder:
    """Adds an if condition branch and returns a new graph builder for it.

    The resulting if graph will only execute the branch if the
    condition evaluates to true at runtime.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        whether the branch executes.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created conditional graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.3.
    TypeError
        If `condition` is not a GraphCondition.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            "condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = condition.handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    node_params.conditional.size = 1  # a single "then" branch
    node_params.conditional.ctx = self._get_conditional_context()

    branches = self._cond_with_params(node_params)
    return branches[0]
def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]:
    """Adds an if-else condition branch and returns new graph builders for both branches.

    The resulting if graph will execute the branch if the condition
    evaluates to true at runtime, otherwise the else branch will execute.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        which branch executes.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        A tuple of two new graph builders, one for the if branch and one for the else branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.
    TypeError
        If `condition` is not a GraphCondition.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if-else")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if-else")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            "condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = condition.handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    node_params.conditional.size = 2  # one graph for "if", one for "else"
    node_params.conditional.ctx = self._get_conditional_context()

    branches = self._cond_with_params(node_params)
    return branches
def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]:
    """Adds a switch condition branch and returns new graph builders for all cases.

    The resulting switch graph will execute the branch whose case index
    matches the value of the condition at runtime. If no match is found, no
    branch will be executed.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` selecting
        which case executes.
    count : int
        The number of cases to add to the switch conditional.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of new graph builders, one for each branch.

    Raises
    ------
    RuntimeError
        If the driver or binding version is older than 12.8.
    TypeError
        If `condition` is not a GraphCondition.

    """
    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional switch")
    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional switch")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            "condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}"
        )

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = condition.handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    node_params.conditional.size = count  # one graph per case
    node_params.conditional.ctx = self._get_conditional_context()

    branches = self._cond_with_params(node_params)
    return branches
def while_loop(self, condition: GraphCondition) -> GraphBuilder:
    """Add a while-loop conditional node and return a builder for its body.

    At runtime the resulting while loop graph executes the branch
    repeatedly until the condition evaluates to false.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        loop continuation.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created while loop graph builder.

    Raises
    ------
    RuntimeError
        If the driver or the bindings are older than 12.3.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    # Both the driver and the bindings must be at least this version.
    required = (12, 3, 0)

    driver_ver = cy_driver_version()
    if driver_ver < required:
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional while loop")

    binding_ver = cy_binding_version()
    if binding_ver < required:
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional while loop")

    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")

    # Describe a WHILE conditional node with a single body.
    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.handle = condition.handle
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    params.conditional.size = 1
    params.conditional.ctx = self._get_conditional_context()

    # Only one body was requested, so hand back that single builder.
    loop_builders = self._cond_with_params(params)
    return loop_builders[0]
def close(self) -> None:
    """Destroy the graph builder.

    Closes the associated stream if this builder owns it. A borrowed
    stream object instead has its reference released.

    """
    # All teardown is delegated to the finalizer-state object.
    members = self._mnff
    members.close()
def embed(self, child: GraphBuilder) -> None:
    """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node.

    Parameters
    ----------
    child : :obj:`~graph.GraphBuilder`
        The child graph builder. Must have finished building.

    Raises
    ------
    ValueError
        If ``child`` has not finished building, or if this (parent)
        builder is not currently being built.
    """
    if not child._building_ended:
        raise ValueError("Child graph has not finished building.")

    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    # Query the graph under capture and the current capture dependencies.
    capture_stream = self._mnff.stream.handle
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(capture_stream))
    _, _, graph_out, *deps_info_out, num_dependencies_out = capture_info

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    deps_info_trimmed = deps_info_out[:num_dependencies_out]

    # Add the child graph as a node that depends on the current capture tail.
    child_node = handle_return(
        driver.cuGraphAddChildGraphNode(
            graph_out, *deps_info_trimmed, num_dependencies_out, child._mnff.graph
        )
    )

    # The new node becomes the only capture dependency; any remaining
    # positional slots (e.g. edge data) are passed as None.
    deps_info_update = [[child_node]] + [None] * (len(deps_info_out) - 1)
    set_dependencies = driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            capture_stream,
            *deps_info_update,  # dependencies, edgeData
            1,
            set_dependencies,
        )
    )
def callback(self, fn, *, user_data=None) -> None:
    """Add a host callback to the graph during stream capture.

    The callback runs on the host CPU when the graph reaches this point
    in execution. Two modes are supported:

    - **Python callable**: Pass any callable. The GIL is acquired
      automatically. The callable must take no arguments; use closures
      or ``functools.partial`` to bind state.
    - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance.
      The function receives a single ``void*`` argument (the
      ``user_data``). The caller must keep the ctypes wrapper alive
      for the lifetime of the graph.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The callback function.
    user_data : int or bytes-like, optional
        Only for ctypes function pointers. If ``int``, passed as a raw
        pointer (caller manages lifetime). If bytes-like, the data is
        copied and its lifetime is tied to the graph.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    # Typed view of the builder's stream so its C-level handle can be read.
    cdef Stream stream = <Stream>self._mnff.stream
    cdef cydriver.CUstream c_stream = as_cu(stream._h_stream)
    cdef cydriver.CUstreamCaptureStatus capture_status
    cdef cydriver.CUgraph c_graph = NULL

    # Fetch the capture status and the graph being captured. The argument
    # list is chosen at compile time: builds targeting CUDA major >= 13 pass
    # one additional NULL output argument.
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                c_stream, &capture_status, NULL, &c_graph, NULL, NULL, NULL))
        ELSE:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                c_stream, &capture_status, NULL, &c_graph, NULL, NULL))

    if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    cdef cydriver.CUhostFn c_fn
    cdef void* c_user_data = NULL
    # NOTE(review): presumably this resolves ``fn``/``user_data`` into a C
    # function pointer and payload (written to c_fn / c_user_data) and ties
    # any owned state to c_graph -- confirm against the helper's definition.
    _attach_host_callback_to_graph(c_graph, fn, user_data, &c_fn, &c_user_data)

    # Enqueue the host function on the capturing stream.
    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data))
class Graph:
    """An executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and
    executes them with a specified dependency tree. It speeds up the workflow by
    combining the driver activities associated with CUDA kernel launches and CUDA
    API calls.

    Graphs must be built using a :obj:`~graph.GraphBuilder` object.

    """

    class _MembersNeededForFinalize:
        # Holds the executable-graph handle separately from the Graph object
        # so that the finalizer callback does not reference the Graph itself.
        __slots__ = "graph"

        def __init__(self, graph_obj: Graph, graph: driver.CUgraphExec) -> None:
            self.graph = graph
            # Destroy the handle when ``graph_obj`` is garbage collected.
            weakref.finalize(graph_obj, self.close)

        def close(self) -> None:
            # Nothing to do when the handle is already released (or falsy).
            if not self.graph:
                return
            handle_return(driver.cuGraphExecDestroy(self.graph))
            self.graph = None

    __slots__ = ("__weakref__", "_mnff")

    def __init__(self) -> None:
        # Instances are only created through the private ``_init`` constructor.
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @classmethod
    def _init(cls, graph: driver.CUgraphExec) -> Graph:
        # Bypass __init__ (which always raises) and attach the finalizer state.
        instance = cls.__new__(cls)
        instance._mnff = Graph._MembersNeededForFinalize(instance, graph)
        return instance

    def close(self) -> None:
        """Destroy the graph."""
        members = self._mnff
        members.close()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        return self._mnff.graph

    def update(self, source: "GraphBuilder | GraphDefinition") -> None:
        """Update the graph using a new graph definition.

        The topology of the provided source must be identical to this graph.

        Parameters
        ----------
        source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
            The graph definition to update from. A GraphBuilder must have
            finished building.

        Raises
        ------
        ValueError
            If ``source`` is a GraphBuilder that has not finished building.
        TypeError
            If ``source`` is neither a GraphBuilder nor a GraphDefinition.
        CUDAError
            If the driver reports that the executable graph update failed.

        """
        from cuda.core.graph import GraphDefinition

        cdef cydriver.CUgraph cu_graph
        cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)
        cdef cydriver.CUgraphExecUpdateResultInfo result_info
        cdef cydriver.CUresult err

        # Reject unsupported sources first, then pick the source graph handle.
        if not isinstance(source, (GraphBuilder, GraphDefinition)):
            raise TypeError(
                f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}")
        if isinstance(source, GraphBuilder):
            if not source._building_ended:
                raise ValueError("Graph has not finished building.")
            source_handle = source._mnff.graph
        else:
            source_handle = source.handle
        cu_graph = <cydriver.CUgraph><intptr_t>int(source_handle)

        with nogil:
            err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)

        # Report the driver's specific update-failure reason; any other error
        # goes through the generic handler below.
        if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
            failure = driver.CUgraphExecUpdateResult(result_info.result)
            raise CUDAError(f"Graph update failed: {failure.__doc__.strip()} ({failure.name})")
        HANDLE_RETURN(err)

    def upload(self, stream: Stream) -> None:
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        status = driver.cuGraphUpload(self._mnff.graph, stream.handle)
        handle_return(status)

    def launch(self, stream: Stream) -> None:
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph.

        """
        status = driver.cuGraphLaunch(self._mnff.graph, stream.handle)
        handle_return(status)