Coverage for cuda/core/graph/_graph_builder.pyx: 89.32%
412 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-03 01:38 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-03 01:38 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from dataclasses import dataclass
6from typing import TYPE_CHECKING
8from libc.stdint cimport intptr_t
10from cuda.bindings cimport cydriver
12from cuda.core.graph._graph_definition cimport GraphCondition
13from cuda.core.graph._utils cimport _attach_host_callback_to_graph
14from cuda.core._resource_handles cimport (
15 GraphHandle,
16 as_cu, as_py,
17 create_graph_exec_handle, create_graph_handle, create_graph_handle_ref,
18)
19from cuda.core._stream cimport Stream
20from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
21from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
23from cuda.core._utils.cuda_utils import (
24 CUDAError,
25 driver,
26 handle_return,
27)
29if TYPE_CHECKING:
30 from cuda.core.graph._graph_definition import GraphDefinition
32__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions']
@dataclass
class GraphDebugPrintOptions:
    """Selects which details debug_dot_print() writes to the DOT output.

    Every option is a boolean and is disabled unless set explicitly.

    Attributes
    ----------
    verbose : bool
        Output all debug data as if every debug flag is enabled (Default to False)
    runtime_types : bool
        Use CUDA Runtime structures for output (Default to False)
    kernel_node_params : bool
        Adds kernel parameter values to output (Default to False)
    memcpy_node_params : bool
        Adds memcpy parameter values to output (Default to False)
    memset_node_params : bool
        Adds memset parameter values to output (Default to False)
    host_node_params : bool
        Adds host parameter values to output (Default to False)
    event_node_params : bool
        Adds event parameter values to output (Default to False)
    ext_semas_signal_node_params : bool
        Adds external semaphore signal parameter values to output (Default to False)
    ext_semas_wait_node_params : bool
        Adds external semaphore wait parameter values to output (Default to False)
    kernel_node_attributes : bool
        Adds kernel node attributes to output (Default to False)
    handles : bool
        Adds node handles and every kernel function handle to output (Default to False)
    mem_alloc_node_params : bool
        Adds memory alloc parameter values to output (Default to False)
    mem_free_node_params : bool
        Adds memory free parameter values to output (Default to False)
    batch_mem_op_node_params : bool
        Adds batch mem op parameter values to output (Default to False)
    extra_topo_info : bool
        Adds edge numbering information (Default to False)
    conditional_node_params : bool
        Adds conditional node parameter values to output (Default to False)

    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Fold the enabled options into a driver flag bitmask (internal use).

        An enabled option ``name`` contributes the ``CUgraphDebugDot_flags``
        member called ``CU_GRAPH_DEBUG_DOT_FLAGS_<NAME>``. The driver enum is
        looked up only for options that are enabled, so an all-default
        instance yields ``0`` without touching the driver module.
        """
        option_names = (
            "verbose",
            "runtime_types",
            "kernel_node_params",
            "memcpy_node_params",
            "memset_node_params",
            "host_node_params",
            "event_node_params",
            "ext_semas_signal_node_params",
            "ext_semas_wait_node_params",
            "kernel_node_attributes",
            "handles",
            "mem_alloc_node_params",
            "mem_free_node_params",
            "batch_mem_op_node_params",
            "extra_topo_info",
            "conditional_node_params",
        )
        flags = 0
        for option_name in option_names:
            if getattr(self, option_name):
                # Every option maps onto the enum member with the same
                # upper-cased name behind the common prefix.
                member_name = f"CU_GRAPH_DEBUG_DOT_FLAGS_{option_name.upper()}"
                flags |= getattr(driver.CUgraphDebugDot_flags, member_name)
        return flags
@dataclass
class GraphCompleteOptions:
    """Settings applied when a built graph is instantiated.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Automatically free memory allocated in a graph before relaunching.
        (Default to False)
    upload_stream : Stream, optional
        Stream to use to automatically upload the graph after completion.
        (Default to None)
    device_launch : bool, optional
        Configure the graph to be launchable from the device. This flag can
        only be used on platforms which support unified addressing, and it
        cannot be used in conjunction with auto_free_on_launch.
        (Default to False)
    use_node_priority : bool, optional
        Run the graph using the per-node priority attributes rather than the
        priority of the stream it is launched into. (Default to False)

    """

    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate ``h_graph`` into an executable graph and wrap it in a Graph.

    Parameters
    ----------
    h_graph
        Python-level graph handle passed to ``cuGraphInstantiateWithParams``.
    options : GraphCompleteOptions, optional
        Instantiation options. When omitted (or falsy) the instantiate
        parameters are left at their defaults.

    Raises
    ------
    RuntimeError
        If the instantiation result reported by the driver is anything other
        than ``CUDA_GRAPH_INSTANTIATE_SUCCESS``.
    """
    cdef cydriver.CUgraphExec c_exec
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        inst_flags = driver.CUgraphInstantiate_flags
        flags = 0
        if options.auto_free_on_launch:
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            # Uploading needs both the flag and the stream to upload on.
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = flags

    py_exec = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))

    # Inspect result_out before wrapping the exec: on a non-SUCCESS result the
    # exec may be invalid, and Graph._init's RAII deleter would call
    # cuGraphExecDestroy on it while the exception below unwinds.
    # NOTE(review): handle_return presumably raises on a failing CUresult
    # before this point is reached -- confirm which branches are reachable.
    outcome = params.result_out
    result_kind = driver.CUgraphInstantiateResult
    if outcome != result_kind.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        if outcome == result_kind.CUDA_GRAPH_INSTANTIATE_ERROR:
            reason = (
                "Instantiation failed for an unexpected reason which is described in the return value of the function."
            )
        elif outcome == result_kind.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
            reason = "Instantiation failed due to invalid structure, such as cycles."
        elif outcome == result_kind.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
            reason = (
                "Instantiation for device launch failed because the graph contained an unsupported operation."
            )
        elif outcome == result_kind.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
            reason = "Instantiation for device launch failed due to the nodes belonging to different contexts."
        elif (
            # The enum member is only looked up when the bindings are new
            # enough, so the version test has to come first.
            cy_binding_version() >= (12, 8, 0)
            and outcome == result_kind.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
        ):
            reason = "One or more conditional handles are not associated with conditional builders."
        else:
            reason = f"Graph instantiation failed with unexpected error code: {outcome}"
        raise RuntimeError(reason)

    c_exec = <cydriver.CUgraphExec><intptr_t>int(py_exec)
    return Graph._init(c_exec)
201# Distinguishes the three kinds of GraphBuilder, which differ in how they
202# begin/end stream capture and whether they own the resulting CUgraph.
203# Each kind progresses through _CaptureState as follows:
204#
205# PRIMARY: NOT_STARTED -> CAPTURING -> ENDED
206# FORKED: CAPTURING (never transitions; joined and closed)
207# CONDITIONAL_BODY: NOT_STARTED -> CAPTURING -> ENDED
208#
cdef enum _BuilderKind:
    # PRIMARY -- the top-level builder handed out by Device or Stream. It
    # holds the captured CUgraph through an owning GraphHandle, moves through
    # all three capture states, and is the one responsible for ending the
    # capture if it is destroyed early.
    PRIMARY = 0
    # FORKED -- produced by split(). It captures on a private stream forked
    # from the primary, starts out in the CAPTURING state and never leaves
    # it; join() merges it back into the primary and closes it. It must NOT
    # call cuStreamEndCapture, because the driver requires every forked
    # stream to be joined first.
    FORKED = 1
    # CONDITIONAL_BODY -- produced by if_then/if_else/switch/while_loop. It
    # captures into a non-owned body graph via cuStreamBeginCaptureToGraph;
    # that body graph lives as long as its parent graph. Like PRIMARY, it
    # moves through all three capture states.
    CONDITIONAL_BODY = 2
227# Tracks the capture lifecycle of a GraphBuilder.
cdef enum _CaptureState:
    # Initial state of a builder made by GraphBuilder._init.
    CAPTURE_NOT_STARTED = 0
    # Set by begin_building() once stream capture has begun.
    CAPTURING = 1
    # Finished, valid handle
    CAPTURE_ENDED = 2
    # No valid handle
    CLOSED = 3
235cdef class GraphBuilder:
236 """A graph under construction by stream capture.
238 A graph groups a set of CUDA kernels and other CUDA operations together and executes
239 them with a specified dependency tree. It speeds up the workflow by combining the
240 driver activities associated with CUDA kernel launches and CUDA API calls.
242 Directly creating a :obj:`~graph.GraphBuilder` is not supported due
243 to ambiguity. New graph builders should instead be created through a
244 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object.
246 """
def __init__(self):
    # Direct construction is rejected on purpose; builders are created
    # through Device.create_graph_builder() or stream.create_graph_builder().
    message = (
        "directly creating a GraphBuilder object can be ambiguous. Please either "
        "call Device.create_graph_builder() or stream.create_graph_builder()"
    )
    raise NotImplementedError(message)
def __dealloc__(self):
    # Same cleanup helper that close() uses, but with False as the second
    # argument (close() passes True). The helper is defined outside this
    # chunk, so the exact meaning of that flag should be checked there.
    GB_end_capture_if_needed(self, False)
@staticmethod
def _init(Stream stream):
    """Build a PRIMARY graph builder bound to ``stream``, capture not yet started."""
    cdef GraphBuilder builder = GraphBuilder.__new__(GraphBuilder)
    builder._stream = stream
    builder._h_stream = stream._h_stream
    builder._kind = PRIMARY
    builder._state = CAPTURE_NOT_STARTED
    # builder._h_graph is left unset here; begin_building() fills it in.
    return builder
def close(self):
    """Destroy the graph builder.

    Runs the end-capture cleanup helper, releases the graph and stream
    handles, marks the builder as CLOSED and drops its stream reference.
    """
    GB_end_capture_if_needed(self, True)
    # Release both handles before flagging the builder as closed.
    self._h_graph.reset()
    self._h_stream.reset()
    self._state = CLOSED
    self._stream = None
@property
def stream(self) -> Stream:
    """The stream this graph builder is associated with.

    This is ``None`` once :meth:`close` has been called.
    """
    return self._stream
@property
def is_join_required(self) -> bool:
    """Whether this builder has to be joined before building is ended.

    True exactly for builders of the FORKED kind.
    """
    return self._kind == FORKED
def begin_building(self, mode: str | None = "relaxed") -> GraphBuilder:
    """Start capturing work on the builder's stream.

    The build ``mode`` controls how capture interacts with other API calls
    and must be one of:

    - ``global`` : Prohibit potentially unsafe operations across all streams in the process.
    - ``thread_local`` : Prohibit potentially unsafe operations in streams created by the current thread.
    - ``relaxed`` : The local thread is not prohibited from potentially unsafe operations.

    Parameters
    ----------
    mode : str, optional
        Build mode to control the interaction with other API calls that are
        potentially unsafe. Defaults to ``"relaxed"``.

    Returns
    -------
    GraphBuilder
        This builder, to allow call chaining.

    Raises
    ------
    RuntimeError
        If the builder is already building, or building has already ended.
    ValueError
        If ``mode`` is not one of the supported values.

    """
    GB_check_open(self)
    if self._state == CAPTURING:
        raise RuntimeError("Graph builder is already building.")
    if self._state != CAPTURE_NOT_STARTED:
        raise RuntimeError("Cannot resume building after building has ended.")

    cdef cydriver.CUstreamCaptureMode c_mode
    if mode == "relaxed":
        c_mode = cydriver.CU_STREAM_CAPTURE_MODE_RELAXED
    elif mode == "thread_local":
        c_mode = cydriver.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL
    elif mode == "global":
        c_mode = cydriver.CU_STREAM_CAPTURE_MODE_GLOBAL
    else:
        raise ValueError(f"Unsupported build mode: {mode}")

    cdef cydriver.CUstream c_stream = as_cu(self._h_stream)
    cdef cydriver.CUgraph c_graph
    cdef cydriver.CUstreamCaptureStatus c_status
    if self._kind != CONDITIONAL_BODY:
        with nogil:
            HANDLE_RETURN(cydriver.cuStreamBeginCapture(c_stream, c_mode))
        # Capture is live from here on. Record CAPTURING before the calls
        # below so that a failure in _get_capture_info/create_graph_handle
        # still lets cleanup end the capture instead of leaving the stream
        # poisoned.
        self._state = CAPTURING
        with nogil:
            # The driver rejects a NULL captureStatus_out, so a stack-local
            # is passed even though only the graph handle is wanted.
            _get_capture_info(c_stream, &c_status, &c_graph)
        self._h_graph = create_graph_handle(c_graph)
        return self

    # Conditional bodies capture into the graph the builder already holds.
    c_graph = as_cu(self._h_graph)
    with nogil:
        HANDLE_RETURN(cydriver.cuStreamBeginCaptureToGraph(
            c_stream, c_graph, NULL, NULL, 0, c_mode))
    self._state = CAPTURING
    return self
@property
def is_building(self) -> bool:
    """Whether the builder's stream is currently capturing.

    Raises
    ------
    RuntimeError
        If the capture has been invalidated.
    NotImplementedError
        If the driver reports a capture status this code does not know.
    """
    GB_check_open(self)
    cdef cydriver.CUstream c_stream = as_cu(self._h_stream)
    cdef cydriver.CUstreamCaptureStatus capture_status
    with nogil:
        _get_capture_info(c_stream, &capture_status, NULL)

    if capture_status == cydriver.CU_STREAM_CAPTURE_STATUS_NONE:
        return False
    if capture_status == cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return True
    if capture_status == cydriver.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. Build process must now be ended."
        )
    raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")
def end_building(self) -> GraphBuilder:
    """Stop capturing on the builder's stream and mark the build as ended.

    Returns
    -------
    GraphBuilder
        This builder, to allow call chaining.

    Raises
    ------
    RuntimeError
        If the builder is not currently building.
    """
    GB_check_open(self)
    # NOTE(review): is_building raises for an invalidated capture, so in that
    # case this method raises before cuStreamEndCapture is reached.
    if not self.is_building:
        raise RuntimeError("Graph builder is not building.")

    cdef cydriver.CUgraph captured_graph
    cdef cydriver.CUstream c_stream = as_cu(self._h_stream)
    with nogil:
        # The graph written to captured_graph is not stored by this method.
        HANDLE_RETURN(cydriver.cuStreamEndCapture(c_stream, &captured_graph))

    # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
    # resume the build process after the first call to end_building()
    self._state = CAPTURE_ENDED
    return self
def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate the captured graph and return it as a :obj:`~graph.Graph`.

    Parameters
    ----------
    options : :obj:`~graph.GraphCompleteOptions`, optional
        Customizable dataclass for the graph builder completion options.

    Returns
    -------
    graph : :obj:`~graph.Graph`
        The newly built graph.

    Raises
    ------
    RuntimeError
        If building has not been ended yet.

    """
    GB_check_open(self)
    build_finished = self._state == CAPTURE_ENDED
    if not build_finished:
        raise RuntimeError("Graph has not finished building.")

    py_graph = as_py(self._h_graph)
    return _instantiate_graph(py_graph, options)
def debug_dot_print(self, path: str, options: GraphDebugPrintOptions | None = None) -> None:
    """Write a DOT debug file describing the built graph.

    Parameters
    ----------
    path : str
        File path to use for writing debug DOT output
    options : :obj:`~graph.GraphDebugPrintOptions`, optional
        Customizable dataclass for the debug print options.

    Raises
    ------
    RuntimeError
        If building has not been ended yet.

    """
    GB_check_open(self)
    if self._state != CAPTURE_ENDED:
        raise RuntimeError("Graph has not finished building.")

    # No options means no debug flags at all.
    cdef unsigned int c_flags = 0
    if options:
        c_flags = options._to_flags()
    cdef cydriver.CUgraph c_graph = as_cu(self._h_graph)
    # Keep the encoded bytes object referenced for as long as c_path is used.
    cdef bytes encoded_path = path.encode('utf-8')
    cdef const char* c_path = encoded_path
    with nogil:
        HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(c_graph, c_path, c_flags))
def split(self, count: int) -> tuple[GraphBuilder, ...]:
    """Splits the original graph builder into multiple graph builders.

    The new builders inherit work dependencies from the original builder.
    The original builder is reused for the split and is returned first in the tuple.

    Parameters
    ----------
    count : int
        The number of graph builders to split the graph builder into.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of split graph builders. The first graph builder in the tuple
        is always the original graph builder.

    Raises
    ------
    ValueError
        If ``count`` is less than 2.
    RuntimeError
        If the builder is not currently building.

    """
    if count < 2:
        raise ValueError(f"Invalid split count: expecting >= 2, got {count}")
    GB_check_open(self)
    if self._state != CAPTURING:
        raise RuntimeError("Graph builder must be building before it can be split.")

    # Record an event on this builder's stream; every new stream waits on it
    # before a forked builder is created for that stream.
    event = self._stream.record()
    result = [self]
    try:
        for _ in range(count - 1):
            stream = self._stream.device.create_stream()
            stream.wait(event)
            result.append(GB_init_forked(stream, self._h_graph))
    finally:
        # Release the temporary event even if forking fails part-way, so it
        # is not left open until garbage collection.
        event.close()
    return tuple(result)
@staticmethod
def join(*graph_builders: GraphBuilder) -> GraphBuilder:
    """Merge several graph builders back into one.

    The returned builder inherits work dependencies from the provided builders.
    Every builder other than the returned one is closed.

    Parameters
    ----------
    *graph_builders : :obj:`~graph.GraphBuilder`
        The graph builders to join.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly joined graph builder.

    Raises
    ------
    TypeError
        If any argument is not a GraphBuilder.
    ValueError
        If fewer than two builders are given.

    """
    for candidate in graph_builders:
        if not isinstance(candidate, GraphBuilder):
            raise TypeError("All arguments must be GraphBuilder instances")
    if len(graph_builders) < 2:
        raise ValueError("Must join with at least two graph builders")

    # The root is the first builder that does not itself need joining; when
    # every builder needs joining, the first one is used.
    root_idx = next(
        (pos for pos, bdr in enumerate(graph_builders) if not bdr.is_join_required),
        0,
    )
    root_bdr = graph_builders[root_idx]

    # Make the root wait on each of the others, then close them.
    for pos, other in enumerate(graph_builders):
        if pos != root_idx:
            root_bdr.stream.wait(other.stream)
            other.close()

    return root_bdr
def __cuda_stream__(self) -> tuple[int, int]:
    """Implement the __cuda_stream__ protocol by delegating to the builder's stream."""
    GB_check_open(self)
    underlying_stream = self.stream
    return underlying_stream.__cuda_stream__()
def _get_conditional_context(self) -> driver.CUcontext:
    """Return the context handle of the builder's stream, used for conditional handles and nodes."""
    stream_context = self._stream.context
    return stream_context.handle
def create_condition(self, default_value: int | None = None) -> GraphCondition:
    """Create a condition variable for use with conditional nodes.

    The returned :class:`GraphCondition` is what the conditional-node builder
    methods (:meth:`if_then`, :meth:`if_else`, :meth:`while_loop`,
    :meth:`switch`) take as their condition. Its value is controlled at
    runtime by device code via ``cudaGraphSetConditional``.

    Parameters
    ----------
    default_value : int, optional
        The default value to assign to the condition. If None, no
        default is assigned.

    Returns
    -------
    GraphCondition
        A condition variable for controlling conditional execution.

    Raises
    ------
    RuntimeError
        If the driver or the bindings are older than 12.3.0, or if the
        builder's stream is not actively capturing.
    """
    GB_check_open(self)
    if cy_driver_version() < (12, 3, 0):
        driver_text = '.'.join(map(str, cy_driver_version()))
        raise RuntimeError(f"Driver version {driver_text} does not support conditional handles")
    if cy_binding_version() < (12, 3, 0):
        binding_text = '.'.join(map(str, cy_binding_version()))
        raise RuntimeError(f"Binding version {binding_text} does not support conditional handles")

    # Without a default, a value of 0 is passed and the "assign default"
    # flag is left off.
    if default_value is None:
        default_value = 0
        flags = 0
    else:
        flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT

    capture_status, _, capturing_graph, *_, _ = handle_return(
        driver.cuStreamGetCaptureInfo(self._stream.handle)
    )
    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot create a condition when graph is not being built")

    raw_handle = handle_return(
        driver.cuGraphConditionalHandleCreate(
            capturing_graph, self._get_conditional_context(), default_value, flags
        )
    )
    return GraphCondition._from_handle(<cydriver.CUgraphConditionalHandle><intptr_t>int(raw_handle))
def if_then(self, condition: GraphCondition) -> GraphBuilder:
    """Add an ``if`` conditional and return the builder for its body.

    The resulting if graph will only execute the branch if the
    condition evaluates to true at runtime.

    The new builder inherits work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        whether the branch executes.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created conditional graph builder.

    Raises
    ------
    RuntimeError
        If the driver or the bindings are older than 12.3.0.
    TypeError
        If ``condition`` is not a GraphCondition.

    """
    GB_check_open(self)
    if cy_driver_version() < (12, 3, 0):
        driver_text = '.'.join(map(str, cy_driver_version()))
        raise RuntimeError(f"Driver version {driver_text} does not support conditional if")
    if cy_binding_version() < (12, 3, 0):
        binding_text = '.'.join(map(str, cy_binding_version()))
        raise RuntimeError(f"Binding version {binding_text} does not support conditional if")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond_params = node_params.conditional
    cond_params.handle = condition.handle
    cond_params.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    # A plain "if" has a single body graph.
    cond_params.size = 1
    cond_params.ctx = self._get_conditional_context()

    body_builders = GB_cond_with_params(self, node_params)
    return body_builders[0]
def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]:
    """Add an ``if``/``else`` conditional and return builders for both bodies.

    The resulting if graph will execute the branch if the condition
    evaluates to true at runtime, otherwise the else branch will execute.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` controlling
        which branch executes.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        A tuple of two new graph builders, one for the if branch and one for the else branch.

    Raises
    ------
    RuntimeError
        If the driver or the bindings are older than 12.8.0.
    TypeError
        If ``condition`` is not a GraphCondition.

    """
    GB_check_open(self)
    if cy_driver_version() < (12, 8, 0):
        driver_text = '.'.join(map(str, cy_driver_version()))
        raise RuntimeError(f"Driver version {driver_text} does not support conditional if-else")
    if cy_binding_version() < (12, 8, 0):
        binding_text = '.'.join(map(str, cy_binding_version()))
        raise RuntimeError(f"Binding version {binding_text} does not support conditional if-else")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")

    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond_params = node_params.conditional
    cond_params.handle = condition.handle
    # Same conditional type as if_then; the second body graph is the else branch.
    cond_params.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    cond_params.size = 2
    cond_params.ctx = self._get_conditional_context()

    return GB_cond_with_params(self, node_params)
def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]:
    """Adds a switch condition branch and returns new graph builders for all cases.

    The resulting switch graph will execute the branch whose case index
    matches the value of the condition at runtime. If no match is found, no
    branch will be executed.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        The condition variable from :meth:`create_condition` selecting
        which case executes.
    count : int
        The number of cases to add to the switch conditional. Must be at
        least 1.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        A tuple of new graph builders, one for each branch.

    Raises
    ------
    RuntimeError
        If this builder has been closed, or if the driver or binding
        version is older than 12.8.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.
    ValueError
        If ``count`` is less than 1.

    """
    GB_check_open(self)
    if cy_driver_version() < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional switch")
    if cy_binding_version() < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional switch")
    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")
    # A switch conditional needs at least one body graph. Reject bad counts
    # here with a clear message instead of letting the value reach the
    # binding/driver, where it only surfaces as an opaque error.
    if count < 1:
        raise ValueError(f"count must be a positive integer, got {count!r}")
    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = condition.handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    # One body graph (and therefore one returned builder) per case.
    node_params.conditional.size = count
    node_params.conditional.ctx = self._get_conditional_context()
    return GB_cond_with_params(self, node_params)
def while_loop(self, condition: GraphCondition) -> GraphBuilder:
    """Add a while-loop conditional node and return the builder for its body.

    At runtime the loop body executes repeatedly until the condition
    evaluates to false.

    The returned builder inherits work dependencies from this builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable obtained from :meth:`create_condition` that
        controls loop continuation.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        The newly created graph builder for the loop body.

    Raises
    ------
    RuntimeError
        If this builder has been closed, or if the driver or binding
        version is older than 12.3.0.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    GB_check_open(self)

    # While-loop conditionals require 12.3 support in both driver and bindings.
    if cy_driver_version() < (12, 3, 0):
        driver_ver = '.'.join(map(str, cy_driver_version()))
        raise RuntimeError(f"Driver version {driver_ver} does not support conditional while loop")
    if cy_binding_version() < (12, 3, 0):
        binding_ver = '.'.join(map(str, cy_binding_version()))
        raise RuntimeError(f"Binding version {binding_ver} does not support conditional while loop")

    if not isinstance(condition, GraphCondition):
        got = type(condition).__name__
        raise TypeError(
            "condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {got}")

    loop_params = driver.CUgraphNodeParams()
    loop_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    loop_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    loop_params.conditional.size = 1
    loop_params.conditional.handle = condition.handle
    loop_params.conditional.ctx = self._get_conditional_context()

    # A while conditional has a single body graph, hence a single builder.
    body_builders = GB_cond_with_params(self, loop_params)
    return body_builders[0]
def embed(self, GraphBuilder child not None):
    """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node.

    The child graph is added as a child-graph node of the graph currently
    being captured, and the capture dependencies of this builder's stream
    are then replaced by that single new node.

    Parameters
    ----------
    child : :obj:`~graph.GraphBuilder`
        The child graph builder. Must have finished building. ``None`` is
        rejected with :class:`TypeError`.

    Raises
    ------
    RuntimeError
        If this builder has been closed.
    ValueError
        If ``child`` has not finished building, or if this builder is not
        currently being built.
    """
    # ``not None`` in the signature matters: a typed extension-type argument
    # accepts None by default, and ``child._state`` below is a C-level
    # attribute read that is not safe on None.
    GB_check_open(self)
    if child._state != CAPTURE_ENDED:
        raise ValueError("Child graph has not finished building.")

    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    stream_handle = self._stream.handle
    _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return(
        driver.cuStreamGetCaptureInfo(stream_handle)
    )

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    deps_info_trimmed = deps_info_out[:num_dependencies_out]
    child_node = handle_return(
        driver.cuGraphAddChildGraphNode(
            graph_out, *deps_info_trimmed, num_dependencies_out, as_py(child._h_graph)
        )
    )
    # The new child node becomes the only capture dependency; the remaining
    # per-dependency slots (e.g. edge data) are passed as None.
    deps_info_update = [[child_node]] + [None] * (len(deps_info_out) - 1)
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            stream_handle,
            *deps_info_update,  # dependencies, edgeData
            1,
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )
def callback(self, fn, *, user_data=None) -> None:
    """Record a host callback into the graph while the stream is capturing.

    The callback executes on the host CPU when graph execution reaches
    this point. Two modes are supported:

    - **Python callable**: Pass any callable. The GIL is acquired
      automatically. The callable must take no arguments; use closures
      or ``functools.partial`` to bind state.
    - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance.
      The function receives a single ``void*`` argument (the
      ``user_data``). The caller must keep the ctypes wrapper alive
      for the lifetime of the graph.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The callback function.
    user_data : int or bytes-like, optional
        Only for ctypes function pointers. If ``int``, passed as a raw
        pointer (caller manages lifetime). If bytes-like, the data is
        copied and its lifetime is tied to the graph.

    Raises
    ------
    RuntimeError
        If this builder has been closed, or if the stream is not
        actively capturing.
    """
    cdef Stream capture_stream
    cdef cydriver.CUstream cu_stream
    cdef cydriver.CUstreamCaptureStatus status
    cdef cydriver.CUgraph cu_graph = NULL
    cdef cydriver.CUhostFn host_fn
    cdef void* host_arg = NULL

    # The closed check must run before the cached stream is dereferenced.
    GB_check_open(self)
    capture_stream = self._stream
    cu_stream = as_cu(capture_stream._h_stream)

    with nogil:
        _get_capture_info(cu_stream, &status, &cu_graph)

    if status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    # Resolve fn/user_data into the C function pointer and argument to launch.
    _attach_host_callback_to_graph(cu_graph, fn, user_data, &host_fn, &host_arg)

    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(cu_stream, host_fn, host_arg))
cdef inline int GB_check_open(GraphBuilder gb) except -1:
    """Raise if the builder has been closed; otherwise return 0.

    A CLOSED builder has reset its stream and graph handles, so any method
    that dereferences them would read a null handle (or, for the cached
    Stream, a None typed as cdef Stream). Checking up front turns that into
    a clear RuntimeError.
    """
    if gb._state != CLOSED:
        return 0
    raise RuntimeError("Graph builder has been closed.")
cdef inline int GB_end_capture_if_needed(GraphBuilder gb, bint check_status) except -1 nogil:
    """Finish the live stream capture when this builder is its owner.

    Only a CAPTURING PRIMARY or CONDITIONAL_BODY builder owns the live
    capture. A FORKED builder must not call cuStreamEndCapture: the driver
    requires forked streams to be joined first.

    With check_status=True the driver result is checked (close()); with
    False it is ignored (__dealloc__). Returns 0 when no error is raised.
    """
    cdef cydriver.CUstream capture_stream
    cdef cydriver.CUgraph ended_graph
    cdef cydriver.CUresult status
    if gb._h_stream:
        if gb._state == CAPTURING and gb._kind != FORKED:
            capture_stream = as_cu(gb._h_stream)
            with nogil:
                status = cydriver.cuStreamEndCapture(capture_stream, &ended_graph)
            if check_status:
                HANDLE_RETURN(status)
    return 0
cdef inline GraphBuilder GB_init_forked(Stream stream, GraphHandle h_primary_graph):
    """Create a FORKED builder, already in the CAPTURING state, on ``stream``."""
    cdef GraphBuilder forked = GraphBuilder.__new__(GraphBuilder)
    forked._kind = FORKED
    forked._state = CAPTURING
    forked._stream = stream
    forked._h_stream = stream._h_stream
    # A FORKED builder captures into the primary's CUgraph. It holds the
    # primary's GraphHandle so conditional bodies created on it (via
    # GB_init_conditional -> create_graph_handle_ref(cond_graph, parent._h_graph))
    # have a valid parent handle to pin.
    forked._h_graph = h_primary_graph
    return forked
cdef inline GraphBuilder GB_init_conditional(Stream stream, cydriver.CUgraph cond_graph, GraphBuilder parent):
    """Create a CONDITIONAL_BODY builder for ``cond_graph`` on ``stream``.

    The new builder starts in the CAPTURE_NOT_STARTED state. Its graph
    handle is created as a reference tied to ``parent``'s graph handle.
    """
    cdef GraphBuilder body = GraphBuilder.__new__(GraphBuilder)
    body._kind = CONDITIONAL_BODY
    body._state = CAPTURE_NOT_STARTED
    body._stream = stream
    body._h_stream = stream._h_stream
    body._h_graph = create_graph_handle_ref(cond_graph, parent._h_graph)
    return body
cdef inline int _get_capture_info(
        cydriver.CUstream stream,
        cydriver.CUstreamCaptureStatus* status,
        cydriver.CUgraph* graph) except?-1 nogil:
    """Query capture status (and optionally the graph) of ``stream``.

    Thin wrapper around ``cuStreamGetCaptureInfo`` that papers over the
    CUDA 12 vs 13 signature change; the remaining out-parameters are
    always passed as NULL.

    ``status`` must be non-NULL: the driver rejects ``captureStatus_out=NULL``
    with ``CUDA_ERROR_INVALID_VALUE``. ``graph`` may be NULL when the caller
    does not need the graph handle.
    """
    cdef int rc
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        # CUDA 13 signature takes one more trailing out-parameter.
        rc = HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
            stream, status, NULL, graph, NULL, NULL, NULL))
    ELSE:
        rc = HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
            stream, status, NULL, graph, NULL, NULL))
    return rc
cdef inline tuple GB_cond_with_params(GraphBuilder gb, node_params):
    """Add a conditional node to the capturing graph and build its body builders.

    The node described by ``node_params`` is added to the graph that
    ``gb``'s stream is capturing into, and becomes the stream's sole
    capture dependency. One CONDITIONAL_BODY builder is then created, each
    on a newly created stream, for every body graph reported in
    ``node_params.conditional.phGraph_out``.

    Returns a tuple of ``node_params.conditional.size`` builders. Raises
    RuntimeError when the stream is not actively capturing.
    """
    capture_info = handle_return(driver.cuStreamGetCaptureInfo(gb._stream.handle))
    status, _, graph, *deps_info, num_dependencies = capture_info
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    cond_node = handle_return(
        driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params)
    )
    # The conditional node replaces the current dependency set; remaining
    # per-dependency slots (e.g. edge data) are passed as None.
    deps_info_update = [[cond_node]] + [None] * (len(deps_info) - 1)

    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            gb._stream.handle,
            *deps_info_update,  # dependencies, edgeData
            1,  # numDependencies
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

    body_builders = []
    for i in range(node_params.conditional.size):
        body_builders.append(
            GB_init_conditional(
                gb._stream.device.create_stream(),
                <cydriver.CUgraph><intptr_t>int(node_params.conditional.phGraph_out[i]),
                gb,
            )
        )
    return tuple(body_builders)
cdef class Graph:
    """An executable graph.

    A graph groups a set of CUDA kernels and other CUDA operations together and executes
    them with a specified dependency tree. It speeds up the workflow by combining the
    driver activities associated with CUDA kernel launches and CUDA API calls.

    Graphs must be built using a :obj:`~graph.GraphBuilder` object.

    """

    def __init__(self):
        raise RuntimeError("directly constructing a Graph instance is not supported")

    @staticmethod
    cdef Graph _init(cydriver.CUgraphExec graph_exec):
        # Internal factory: bypasses __init__ (which always raises) and wraps
        # the raw executable-graph handle.
        cdef Graph self = Graph.__new__(Graph)
        self._h_graph_exec = create_graph_exec_handle(graph_exec)
        return self

    def close(self) -> None:
        """Destroy the graph."""
        self._h_graph_exec.reset()

    @property
    def handle(self) -> driver.CUgraphExec:
        """Return the underlying ``CUgraphExec`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int()`` on the returned object.

        """
        return as_py(self._h_graph_exec)

    def update(self, source: "GraphBuilder | GraphDefinition") -> None:
        """Update the graph using a new graph definition.

        The topology of the provided source must be identical to this graph.

        Parameters
        ----------
        source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
            The graph definition to update from. A GraphBuilder must have
            finished building.

        Raises
        ------
        ValueError
            If ``source`` is a GraphBuilder that has been closed or has not
            finished building.
        TypeError
            If ``source`` is neither a GraphBuilder nor a GraphDefinition.
        CUDAError
            If the driver reports that the executable graph could not be
            updated; the message names the reported failure reason.

        """
        # Imported here rather than at module level (the module-level import
        # of GraphDefinition is TYPE_CHECKING-only).
        from cuda.core.graph import GraphDefinition

        cdef cydriver.CUgraph cu_graph
        cdef cydriver.CUgraphExec cu_exec = as_cu(self._h_graph_exec)

        if isinstance(source, GraphBuilder):
            if (<GraphBuilder>source)._state == CLOSED:
                raise ValueError("Source graph builder has been closed.")
            if (<GraphBuilder>source)._state != CAPTURE_ENDED:
                raise ValueError("Graph has not finished building.")
            cu_graph = as_cu((<GraphBuilder>source)._h_graph)
        elif isinstance(source, GraphDefinition):
            cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
        else:
            raise TypeError(
                f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}")

        cdef cydriver.CUgraphExecUpdateResultInfo result_info
        cdef cydriver.CUresult err
        with nogil:
            err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)
        if err == cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
            reason = driver.CUgraphExecUpdateResult(result_info.result)
            # __doc__ can be None (no docstring available for the enum
            # member). Never let that turn the real update failure into an
            # AttributeError; fall back to the enum name alone.
            detail = (reason.__doc__ or "").strip()
            if detail:
                msg = f"Graph update failed: {detail} ({reason.name})"
            else:
                msg = f"Graph update failed: {reason.name}"
            raise CUDAError(msg)
        # Any other non-success result goes through the generic error path.
        HANDLE_RETURN(err)

    def upload(self, stream: Stream) -> None:
        """Uploads the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to upload the graph

        """
        cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec)
        cdef cydriver.CUstream c_stream = <cydriver.CUstream><intptr_t>int(stream.handle)
        with nogil:
            HANDLE_RETURN(cydriver.cuGraphUpload(c_exec, c_stream))

    def launch(self, stream: Stream) -> None:
        """Launches the graph in a stream.

        Parameters
        ----------
        stream : :obj:`~_stream.Stream`
            The stream in which to launch the graph.

        """
        cdef cydriver.CUgraphExec c_exec = as_cu(self._h_graph_exec)
        cdef cydriver.CUstream c_stream = <cydriver.CUstream><intptr_t>int(stream.handle)
        with nogil:
            HANDLE_RETURN(cydriver.cuGraphLaunch(c_exec, c_stream))