Coverage for cuda/core/graph/_graph_builder.pyx: 89.32%

412 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-03 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5from dataclasses import dataclass 

6from typing import TYPE_CHECKING 

7  

8from libc.stdint cimport intptr_t 

9  

10from cuda.bindings cimport cydriver 

11  

12from cuda.core.graph._graph_definition cimport GraphCondition 

13from cuda.core.graph._utils cimport _attach_host_callback_to_graph 

14from cuda.core._resource_handles cimport ( 

15 GraphHandle, 

16 as_cu, as_py, 

17 create_graph_exec_handle, create_graph_handle, create_graph_handle_ref, 

18) 

19from cuda.core._stream cimport Stream 

20from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

21from cuda.core._utils.version cimport cy_binding_version, cy_driver_version 

22  

23from cuda.core._utils.cuda_utils import ( 

24 CUDAError, 

25 driver, 

26 handle_return, 

27) 

28  

29if TYPE_CHECKING: 

30 from cuda.core.graph._graph_definition import GraphDefinition 

31  

32__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions'] 

33  

34  

35@dataclass 

36class GraphDebugPrintOptions: 

37 """Options for debug_dot_print(). 

38  

39 Attributes 

40 ---------- 

41 verbose : bool 

42 Output all debug data as if every debug flag is enabled (Default to False) 

43 runtime_types : bool 

44 Use CUDA Runtime structures for output (Default to False) 

45 kernel_node_params : bool 

46 Adds kernel parameter values to output (Default to False) 

47 memcpy_node_params : bool 

48 Adds memcpy parameter values to output (Default to False) 

49 memset_node_params : bool 

50 Adds memset parameter values to output (Default to False) 

51 host_node_params : bool 

52 Adds host parameter values to output (Default to False) 

53 event_node_params : bool 

54 Adds event parameter values to output (Default to False) 

55 ext_semas_signal_node_params : bool 

56 Adds external semaphore signal parameter values to output (Default to False) 

57 ext_semas_wait_node_params : bool 

58 Adds external semaphore wait parameter values to output (Default to False) 

59 kernel_node_attributes : bool 

60 Adds kernel node attributes to output (Default to False) 

61 handles : bool 

62 Adds node handles and every kernel function handle to output (Default to False) 

63 mem_alloc_node_params : bool 

64 Adds memory alloc parameter values to output (Default to False) 

65 mem_free_node_params : bool 

66 Adds memory free parameter values to output (Default to False) 

67 batch_mem_op_node_params : bool 

68 Adds batch mem op parameter values to output (Default to False) 

69 extra_topo_info : bool 

70 Adds edge numbering information (Default to False) 

71 conditional_node_params : bool 

72 Adds conditional node parameter values to output (Default to False) 

73  

74 """ 

75  

76 verbose: bool = False 

77 runtime_types: bool = False 

78 kernel_node_params: bool = False 

79 memcpy_node_params: bool = False 

80 memset_node_params: bool = False 

81 host_node_params: bool = False 

82 event_node_params: bool = False 

83 ext_semas_signal_node_params: bool = False 

84 ext_semas_wait_node_params: bool = False 

85 kernel_node_attributes: bool = False 

86 handles: bool = False 

87 mem_alloc_node_params: bool = False 

88 mem_free_node_params: bool = False 

89 batch_mem_op_node_params: bool = False 

90 extra_topo_info: bool = False 

91 conditional_node_params: bool = False 

92  

93 def _to_flags(self) -> int: 

94 """Convert options to CUDA driver API flags (internal use).""" 

95 flags = 0 2mba

96 if self.verbose: 2mba

97 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE 2mba

98 if self.runtime_types: 2mba

99 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES 1a

100 if self.kernel_node_params: 2mba

101 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS 1a

102 if self.memcpy_node_params: 2mba

103 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS 1a

104 if self.memset_node_params: 2mba

105 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS 1a

106 if self.host_node_params: 2mba

107 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS 1a

108 if self.event_node_params: 2mba

109 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS 1a

110 if self.ext_semas_signal_node_params: 2mba

111 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS 1a

112 if self.ext_semas_wait_node_params: 2mba

113 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS 1a

114 if self.kernel_node_attributes: 2mba

115 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES 18a

116 if self.handles: 2mba

117 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES 2mba

118 if self.mem_alloc_node_params: 2mba

119 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS 1a

120 if self.mem_free_node_params: 2mba

121 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS 1a

122 if self.batch_mem_op_node_params: 2mba

123 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS 1a

124 if self.extra_topo_info: 2mba

125 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO 1a

126 if self.conditional_node_params: 2mba

127 flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS 1a

128 return flags 2mba

129  

130  

131@dataclass 

132class GraphCompleteOptions: 

133 """Options for graph instantiation. 

134  

135 Attributes 

136 ---------- 

137 auto_free_on_launch : bool, optional 

138 Automatically free memory allocated in a graph before relaunching. (Default to False) 

139 upload_stream : Stream, optional 

140 Stream to use to automatically upload the graph after completion. (Default to None) 

141 device_launch : bool, optional 

142 Configure the graph to be launchable from the device. This flag can only 

143 be used on platforms which support unified addressing. This flag cannot be 

144 used in conjunction with auto_free_on_launch. (Default to False) 

145 use_node_priority : bool, optional 

146 Run the graph using the per-node priority attributes rather than the 

147 priority of the stream it is launched into. (Default to False) 

148  

149 """ 

150  

151 auto_free_on_launch: bool = False 

152 upload_stream: Stream | None = None 

153 device_launch: bool = False 

154 use_node_priority: bool = False 

155  

156  

157def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph: 

158 cdef cydriver.CUgraphExec c_exec 

159 params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS() 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

160 if options: 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

161 flags = 0 2. / _ : ; ` = ? @ { [ ] | hbibK F I M

162 if options.auto_free_on_launch: 2. / _ : ; ` = ? @ { [ ] | hbibK F I M

163 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH 2. : ? [ hbibK F I M

164 if options.upload_stream: 2. / _ : ; ` = ? @ { [ ] | hbibK F I M

165 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD 1_`{|M

166 params.hUploadStream = options.upload_stream.handle 1_`{|M

167 if options.device_launch: 2. / _ : ; ` = ? @ { [ ] | hbibK F I M

168 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH 1=M

169 if options.use_node_priority: 2. / _ : ; ` = ? @ { [ ] | hbibK F I M

170 flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY 2/ ; @ ] hbibM

171 params.flags = flags 2. / _ : ; ` = ? @ { [ ] | hbibK F I M

172  

173 py_exec = handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)) 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

174 # Check result_out before wrapping the exec: on a non-SUCCESS result the exec 

175 # may be invalid, and Graph._init's RAII deleter would call cuGraphExecDestroy 

176 # on it during the exception unwind below. 

177 if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR: 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

178 raise RuntimeError( 

179 "Instantiation failed for an unexpected reason which is described in the return value of the function." 

180 ) 

181 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE: 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

182 raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.") 

183 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED: 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

184 raise RuntimeError( 

185 "Instantiation for device launch failed because the graph contained an unsupported operation." 

186 ) 

187 elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED: 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

188 raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.") 

189 elif ( 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

190 cy_binding_version() >= (12, 8, 0) 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

191 and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

192 ): 

193 raise RuntimeError("One or more conditional handles are not associated with conditional builders.") 

194 elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS: 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

195 raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}") 

196  

197 c_exec = <cydriver.CUgraphExec><intptr_t>int(py_exec) 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

198 return Graph._init(c_exec) 2G H 9 E U C D Z P 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B } . ~ / _ abbbcbdb: eb; ` = ? fb@ { [ gb] | hbybzbibAbBbCbnbjbkbobpbqbrblbsbtbubvbwbxb' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J ^ L 5 6 M 7

199  

200  

201# Distinguishes the three kinds of GraphBuilder, which differ in how they 

202# begin/end stream capture and whether they own the resulting CUgraph. 

203# Each kind progresses through _CaptureState as follows: 

204# 

205# PRIMARY: NOT_STARTED -> CAPTURING -> ENDED 

206# FORKED: CAPTURING (never transitions; joined and closed) 

207# CONDITIONAL_BODY: NOT_STARTED -> CAPTURING -> ENDED 

208# 

209cdef enum _BuilderKind: 

210 # PRIMARY: The top-level builder created by Device or Stream. Owns the 

211 # captured CUgraph via an owning GraphHandle. Progresses through all three 

212 # capture states; responsible for ending capture if destroyed early. 

213 PRIMARY = 0 

214 # FORKED: Created by split(). Captures on a private stream forked from the 

215 # primary. Starts in CAPTURING state and never transitions; the user joins 

216 # it back to the primary via join(), which closes the builder. Must NOT 

217 # call cuStreamEndCapture (the driver requires all forked streams to be 

218 # joined first). 

219 FORKED = 1 

220 # CONDITIONAL_BODY: Created by if_then/if_else/switch/while_loop. Captures 

221 # into a non-owned body graph via cuStreamBeginCaptureToGraph. The body 

222 # graph's lifetime is tied to a parent graph. Progresses through all three 

223 # capture states like PRIMARY. 

224 CONDITIONAL_BODY = 2 

225  

226  

227# Tracks the capture lifecycle of a GraphBuilder. 

228cdef enum _CaptureState: 

229 CAPTURE_NOT_STARTED = 0 

230 CAPTURING = 1 

231 CAPTURE_ENDED = 2 # Finished, valid handle 

232 CLOSED = 3 # No valid handle 

233  

234  

235cdef class GraphBuilder: 

236 """A graph under construction by stream capture. 

237  

238 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

239 them with a specified dependency tree. It speeds up the workflow by combining the 

240 driver activities associated with CUDA kernel launches and CUDA API calls. 

241  

242 Directly creating a :obj:`~graph.GraphBuilder` is not supported due 

243 to ambiguity. New graph builders should instead be created through a 

244 :obj:`~_device.Device`, or a :obj:`~_stream.Stream` object. 

245  

246 """ 

247  

248 def __init__(self): 

249 raise NotImplementedError( 

250 "directly creating a GraphBuilder object can be ambiguous. Please either " 

251 "call Device.create_graph_builder() or stream.create_graph_builder()" 

252 ) 

253  

254 def __dealloc__(self): 

255 GB_end_capture_if_needed(self, False) 2* G H 9 E U N , C - + D Z DbP 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B ' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J L 5 6 ) M a 7

256  

257 @staticmethod 

258 def _init(Stream stream): 

259 cdef GraphBuilder self = GraphBuilder.__new__(GraphBuilder) 2* G H 9 E U N , C - + D Z DbP 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B ' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J L 5 6 ) M a 7

260 # _h_graph set by begin_building 

261 self._h_stream = stream._h_stream 2* G H 9 E U N , C - + D Z DbP 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B ' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J L 5 6 ) M a 7

262 self._kind = PRIMARY 2* G H 9 E U N , C - + D Z DbP 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B ' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J L 5 6 ) M a 7

263 self._state = CAPTURE_NOT_STARTED 2* G H 9 E U N , C - + D Z DbP 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B ' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J L 5 6 ) M a 7

264 self._stream = stream 2* G H 9 E U N , C - + D Z DbP 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B ' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J L 5 6 ) M a 7

265 return self 2* G H 9 E U N , C - + D Z DbP 0 O b c d e f g h i j k l m n o p q x r s t u v y z A B ' ! $ ( # % 1 Q V 2 R W K F I 3 S X 4 T Y w J L 5 6 ) M a 7

266  

267 def close(self): 

268 """Destroy the graph builder.""" 

269 GB_end_capture_if_needed(self, True) 1UNCD0Obcdefghijklmnopqrstuva

270 self._h_graph.reset() 1UNCD0Obcdefghijklmnopqrstuva

271 self._h_stream.reset() 1UNCD0Obcdefghijklmnopqrstuva

272 self._state = CLOSED 1UNCD0Obcdefghijklmnopqrstuva

273 self._stream = None 1UNCD0Obcdefghijklmnopqrstuva

274  

275 @property 

276 def stream(self) -> Stream: 

277 """Returns the stream associated with the graph builder.""" 

278 return self._stream 1EUNC-DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

279  

280 @property 

281 def is_join_required(self) -> bool: 

282 """Returns True if this graph builder must be joined before building is ended.""" 

283 return self._kind == FORKED 1NCDbcdefghijklmnopqrstuva

284  

285 def begin_building(self, mode: str | None = "relaxed") -> GraphBuilder: 

286 """Begins the building process. 

287  

288 Build `mode` for controlling interaction with other API calls must be one of the following: 

289  

290 - `global` : Prohibit potentially unsafe operations across all streams in the process. 

291 - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread. 

292 - `relaxed` : The local thread is not prohibited from potentially unsafe operations. 

293  

294 Parameters 

295 ---------- 

296 mode : str, optional 

297 Build mode to control the interaction with other API calls that are potentially unsafe. 

298 Default set to use relaxed. 

299  

300 """ 

301 GB_check_open(self) 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

302 if self._state != CAPTURE_NOT_STARTED: 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

303 if self._state == CAPTURING: 1*Z

304 raise RuntimeError("Graph builder is already building.") 1*

305 else: 

306 raise RuntimeError("Cannot resume building after building has ended.") 1Z

307 cdef cydriver.CUstreamCaptureMode c_mode 

308 if mode == "global": 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

309 c_mode = cydriver.CU_STREAM_CAPTURE_MODE_GLOBAL 1'(12K34)

310 elif mode == "thread_local": 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB!$#%QVRWFISXTYwJL56)Ma7

311 c_mode = cydriver.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL 1$%VWIXY)

312 elif mode == "relaxed": 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB!#QRFSTwJL56)Ma7

313 c_mode = cydriver.CU_STREAM_CAPTURE_MODE_RELAXED 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB!#QRFSTwJL56)Ma7

314 else: 

315 raise ValueError(f"Unsupported build mode: {mode}") 1)

316  

317 cdef cydriver.CUstream c_stream = as_cu(self._h_stream) 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

318 cdef cydriver.CUgraph c_graph 

319 cdef cydriver.CUstreamCaptureStatus c_status 

320 if self._kind == CONDITIONAL_BODY: 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

321 c_graph = as_cu(self._h_graph) 1bcdefghijklmnopqxrstuvyzABwa

322 with nogil: 1bcdefghijklmnopqxrstuvyzABwa

323 HANDLE_RETURN(cydriver.cuStreamBeginCaptureToGraph( 1bcdefghijklmnopqxrstuvyzABwa

324 c_stream, c_graph, NULL, NULL, 0, c_mode)) 

325 self._state = CAPTURING 1bcdefghijklmnopqxrstuvyzABwa

326 else: 

327 with nogil: 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

328 HANDLE_RETURN(cydriver.cuStreamBeginCapture(c_stream, c_mode)) 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

329 # Capture is active now; set CAPTURING before the calls below so a 

330 # failure in _get_capture_info/create_graph_handle still lets 

331 # cleanup end the capture rather than leaving the stream poisoned. 

332 self._state = CAPTURING 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

333 with nogil: 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

334 # The driver rejects a NULL captureStatus_out, so pass a 

335 # stack-local even though we only want the graph handle. 

336 _get_capture_info(c_stream, &c_status, &c_graph) 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

337 self._h_graph = create_graph_handle(c_graph) 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

338 return self 1*GH9EUN,C-+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

339  

340 @property 

341 def is_building(self) -> bool: 

342 """Returns True if the graph builder is currently building.""" 

343 GB_check_open(self) 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

344 cdef cydriver.CUstream c_stream = as_cu(self._h_stream) 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

345 cdef cydriver.CUstreamCaptureStatus status 

346 with nogil: 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

347 _get_capture_info(c_stream, &status, NULL) 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

348 if status == cydriver.CU_STREAM_CAPTURE_STATUS_NONE: 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

349 return False 1+

350 elif status == cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: 

351 return True 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

352 elif status == cydriver.CU_STREAM_CAPTURE_STATUS_INVALIDATED: 

353 raise RuntimeError( 

354 "Build process encountered an error and has been invalidated. Build process must now be ended." 

355 ) 

356 else: 

357 raise NotImplementedError(f"Unsupported capture status type received: {status}") 

358  

359 def end_building(self) -> GraphBuilder: 

360 """Ends the building process.""" 

361 GB_check_open(self) 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

362 if not self.is_building: 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

363 raise RuntimeError("Graph builder is not building.") 

364 cdef cydriver.CUstream c_stream = as_cu(self._h_stream) 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

365 cdef cydriver.CUgraph c_graph 

366 with nogil: 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

367 HANDLE_RETURN(cydriver.cuStreamEndCapture(c_stream, &c_graph)) 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

368  

369 # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to 

370 # resume the build process after the first call to end_building() 

371 self._state = CAPTURE_ENDED 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

372 return self 1*GH9EU,C+DZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56)Ma7

373  

374 def complete(self, options: GraphCompleteOptions | None = None) -> Graph: 

375 """Completes the graph builder and returns the built :obj:`~graph.Graph` object. 

376  

377 Parameters 

378 ---------- 

379 options : :obj:`~graph.GraphCompleteOptions`, optional 

380 Customizable dataclass for the graph builder completion options. 

381  

382 Returns 

383 ------- 

384 graph : :obj:`~graph.Graph` 

385 The newly built graph. 

386  

387 """ 

388 GB_check_open(self) 1GH9EUNCDZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56M7

389 if self._state != CAPTURE_ENDED: 1GH9EUCDZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56M7

390 raise RuntimeError("Graph has not finished building.") 19

391  

392 return _instantiate_graph(as_py(self._h_graph), options) 1GH9EUCDZP0ObcdefghijklmnopqxrstuvyzAB'!$(#%1QV2RWKFI3SX4TYwJL56M7

393  

394 def debug_dot_print(self, path: str, options: GraphDebugPrintOptions | None = None) -> None: 

395 """Generates a DOT debug file for the graph builder. 

396  

397 Parameters 

398 ---------- 

399 path : str 

400 File path to use for writing debug DOT output 

401 options : :obj:`~graph.GraphDebugPrintOptions`, optional 

402 Customizable dataclass for the debug print options. 

403  

404 """ 

405 GB_check_open(self) 1a

406 if self._state != CAPTURE_ENDED: 1a

407 raise RuntimeError("Graph has not finished building.") 

408 cdef unsigned int c_flags = options._to_flags() if options else 0 1a

409 cdef cydriver.CUgraph c_graph = as_cu(self._h_graph) 1a

410 cdef bytes b_path = path.encode('utf-8') 1a

411 cdef const char* c_path = b_path 1a

412 with nogil: 1a

413 HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(c_graph, c_path, c_flags)) 1a

414  

415 def split(self, count: int) -> tuple[GraphBuilder, ...]: 

416 """Splits the original graph builder into multiple graph builders. 

417  

418 The new builders inherit work dependencies from the original builder. 

419 The original builder is reused for the split and is returned first in the tuple. 

420  

421 Parameters 

422 ---------- 

423 count : int 

424 The number of graph builders to split the graph builder into. 

425  

426 Returns 

427 ------- 

428 graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...] 

429 A tuple of split graph builders. The first graph builder in the tuple 

430 is always the original graph builder. 

431  

432 """ 

433 if count < 2: 2N C D Dbb c d e f g h i j k l m n o p q r s t u v a

434 raise ValueError(f"Invalid split count: expecting >= 2, got {count}") 1C

435 GB_check_open(self) 2N C D Dbb c d e f g h i j k l m n o p q r s t u v a

436 if self._state != CAPTURING: 2N C D Dbb c d e f g h i j k l m n o p q r s t u v a

437 raise RuntimeError("Graph builder must be building before it can be split.") 2Db

438  

439 event = self._stream.record() 1NCDbcdefghijklmnopqrstuva

440 result = [self] 1NCDbcdefghijklmnopqrstuva

441 for i in range(count - 1): 1NCDbcdefghijklmnopqrstuva

442 stream = self._stream.device.create_stream() 1NCDbcdefghijklmnopqrstuva

443 stream.wait(event) 1NCDbcdefghijklmnopqrstuva

444 result.append(GB_init_forked(stream, self._h_graph)) 1NCDbcdefghijklmnopqrstuva

445 event.close() 1NCDbcdefghijklmnopqrstuva

446 return tuple(result) 1NCDbcdefghijklmnopqrstuva

447  

448 @staticmethod 

449 def join(*graph_builders: GraphBuilder) -> GraphBuilder: 

450 """Joins multiple graph builders into a single graph builder. 

451  

452 The returned builder inherits work dependencies from the provided builders. 

453  

454 Parameters 

455 ---------- 

456 *graph_builders : :obj:`~graph.GraphBuilder` 

457 The graph builders to join. 

458  

459 Returns 

460 ------- 

461 graph_builder : :obj:`~graph.GraphBuilder` 

462 The newly joined graph builder. 

463  

464 """ 

465 if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): 1NCDbcdefghijklmnopqrstuva

466 raise TypeError("All arguments must be GraphBuilder instances") 

467 if len(graph_builders) < 2: 1NCDbcdefghijklmnopqrstuva

468 raise ValueError("Must join with at least two graph builders") 1C

469  

470 # Discover the root builder others should join 

471 root_idx = 0 1NCDbcdefghijklmnopqrstuva

472 for i, builder in enumerate(graph_builders): 1NCDbcdefghijklmnopqrstuva

473 if not builder.is_join_required: 1NCDbcdefghijklmnopqrstuva

474 root_idx = i 1NCDbcdefghijklmnopqrstuva

475 break 1NCDbcdefghijklmnopqrstuva

476  

477 # Join all onto the root builder 

478 root_bdr = graph_builders[root_idx] 1NCDbcdefghijklmnopqrstuva

479 for idx, builder in enumerate(graph_builders): 1NCDbcdefghijklmnopqrstuva

480 if idx == root_idx: 1NCDbcdefghijklmnopqrstuva

481 continue 1NCDbcdefghijklmnopqrstuva

482 root_bdr.stream.wait(builder.stream) 1NCDbcdefghijklmnopqrstuva

483 builder.close() 1NCDbcdefghijklmnopqrstuva

484  

485 return root_bdr 1NCDbcdefghijklmnopqrstuva

486  

487 def __cuda_stream__(self) -> tuple[int, int]: 

488 """Return an instance of a __cuda_stream__ protocol.""" 

489 GB_check_open(self) 

490 return self.stream.__cuda_stream__() 

491  

492 def _get_conditional_context(self) -> driver.CUcontext: 

493 return self._stream.context.handle 1bcdefghijklmnopqxrstuvyzABwa

494  

495 def create_condition(self, default_value: int | None = None) -> GraphCondition: 

496 """Create a condition variable for use with conditional nodes. 

497  

498 The returned :class:`GraphCondition` object is passed to conditional-node 

499 builder methods (:meth:`if_then`, :meth:`if_else`, :meth:`while_loop`, 

500 :meth:`switch`). Its value is controlled at runtime by device code via 

501 ``cudaGraphSetConditional``. 

502  

503 Parameters 

504 ---------- 

505 default_value : int, optional 

506 The default value to assign to the condition. If None, no 

507 default is assigned. 

508  

509 Returns 

510 ------- 

511 GraphCondition 

512 A condition variable for controlling conditional execution. 

513 """ 

514 GB_check_open(self) 1bcdefghijklmnopqxrstuvyzABwa

515 if cy_driver_version() < (12, 3, 0): 1bcdefghijklmnopqxrstuvyzABwa

516 raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional handles") 

517 if cy_binding_version() < (12, 3, 0): 1bcdefghijklmnopqxrstuvyzABwa

518 raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional handles") 

519 if default_value is not None: 1bcdefghijklmnopqxrstuvyzABwa

520 flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT 1yzABw

521 else: 

522 default_value = 0 1bcdefghijklmnopqxrstuva

523 flags = 0 1bcdefghijklmnopqxrstuva

524  

525 status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._stream.handle)) 1bcdefghijklmnopqxrstuvyzABwa

526 if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: 1bcdefghijklmnopqxrstuvyzABwa

527 raise RuntimeError("Cannot create a condition when graph is not being built") 

528  

529 raw_handle = handle_return( 1bcdefghijklmnopqxrstuvyzABwa

530 driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags) 1bcdefghijklmnopqxrstuvyzABwa

531 ) 

532 return GraphCondition._from_handle(<cydriver.CUgraphConditionalHandle><intptr_t>int(raw_handle)) 1bcdefghijklmnopqxrstuvyzABwa

533  

def if_then(self, condition: GraphCondition) -> GraphBuilder:
    """Add a conditional ``if`` node and return a builder for its body.

    The body captured through the returned builder only executes when
    ``condition`` evaluates to true at runtime. The returned builder
    inherits work dependencies from this builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable from :meth:`create_condition` that decides
        whether the body executes.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        Builder for the body of the newly added conditional node.

    Raises
    ------
    RuntimeError
        If this builder has been closed, if the driver or binding version
        is older than 12.3.0, or if the stream is not actively capturing.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    GB_check_open(self)

    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if")

    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if")

    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")

    # Describe a conditional node of type IF with a single body graph.
    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    params.conditional.size = 1
    params.conditional.handle = condition.handle
    params.conditional.ctx = self._get_conditional_context()

    # One body graph was requested, so exactly one builder comes back.
    body_builders = GB_cond_with_params(self, params)
    return body_builders[0]

570  

def if_else(self, condition: GraphCondition) -> tuple[GraphBuilder, GraphBuilder]:
    """Add a conditional ``if``/``else`` node and return builders for both bodies.

    At runtime the first body executes when ``condition`` evaluates to
    true; otherwise the else body executes. Both returned builders inherit
    work dependencies from this builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable from :meth:`create_condition` that selects
        which body executes.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
        Two new builders: the if body first, the else body second.

    Raises
    ------
    RuntimeError
        If this builder has been closed, if the driver or binding version
        is older than 12.8.0, or if the stream is not actively capturing.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    GB_check_open(self)

    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional if-else")

    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional if-else")

    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")

    # An IF-type conditional node with two body graphs (if and else).
    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    params.conditional.size = 2
    params.conditional.handle = condition.handle
    params.conditional.ctx = self._get_conditional_context()

    return GB_cond_with_params(self, params)

607  

def switch(self, condition: GraphCondition, count: int) -> tuple[GraphBuilder, ...]:
    """Add a conditional ``switch`` node and return one builder per case.

    At runtime the case whose index matches the value of ``condition``
    executes; when no case matches, none of them runs. All returned
    builders inherit work dependencies from this builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable from :meth:`create_condition` that selects
        which case executes.
    count : int
        Number of cases the switch conditional should have.

    Returns
    -------
    graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
        One new builder for each case.

    Raises
    ------
    RuntimeError
        If this builder has been closed, if the driver or binding version
        is older than 12.8.0, or if the stream is not actively capturing.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    GB_check_open(self)

    driver_ver = cy_driver_version()
    if driver_ver < (12, 8, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional switch")

    binding_ver = cy_binding_version()
    if binding_ver < (12, 8, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional switch")

    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")

    # A SWITCH-type conditional node with `count` body graphs.
    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    params.conditional.size = count
    params.conditional.handle = condition.handle
    params.conditional.ctx = self._get_conditional_context()

    return GB_cond_with_params(self, params)

647  

def while_loop(self, condition: GraphCondition) -> GraphBuilder:
    """Add a conditional ``while`` node and return a builder for its body.

    At runtime the body executes repeatedly until ``condition`` evaluates
    to false. The returned builder inherits work dependencies from this
    builder.

    Parameters
    ----------
    condition : :class:`~graph.GraphCondition`
        Condition variable from :meth:`create_condition` that controls
        whether the loop continues.

    Returns
    -------
    graph_builder : :obj:`~graph.GraphBuilder`
        Builder for the body of the newly added while loop.

    Raises
    ------
    RuntimeError
        If this builder has been closed, if the driver or binding version
        is older than 12.3.0, or if the stream is not actively capturing.
    TypeError
        If ``condition`` is not a :class:`~graph.GraphCondition`.

    """
    GB_check_open(self)

    driver_ver = cy_driver_version()
    if driver_ver < (12, 3, 0):
        raise RuntimeError(f"Driver version {'.'.join(map(str, driver_ver))} does not support conditional while loop")

    binding_ver = cy_binding_version()
    if binding_ver < (12, 3, 0):
        raise RuntimeError(f"Binding version {'.'.join(map(str, binding_ver))} does not support conditional while loop")

    if not isinstance(condition, GraphCondition):
        raise TypeError(
            f"condition must be a GraphCondition object (from "
            f"GraphBuilder.create_condition()), got {type(condition).__name__}")

    # A WHILE-type conditional node with a single body graph.
    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    params.conditional.size = 1
    params.conditional.handle = condition.handle
    params.conditional.ctx = self._get_conditional_context()

    # One body graph was requested, so exactly one builder comes back.
    body_builders = GB_cond_with_params(self, params)
    return body_builders[0]

684  

def embed(self, GraphBuilder child):
    """Embed a previously-built :obj:`~graph.GraphBuilder` as a child node.

    The child graph is added as a child-graph node that depends on the
    stream's current capture dependencies, and the stream's capture
    dependencies are then replaced by that single new node.

    Parameters
    ----------
    child : :obj:`~graph.GraphBuilder`
        The child graph builder. Must have finished building.

    Raises
    ------
    RuntimeError
        If this builder has been closed.
    ValueError
        If ``child`` has not finished building, or if this builder is not
        currently being built.
    """
    GB_check_open(self)
    if child._state != CAPTURE_ENDED:
        raise ValueError("Child graph has not finished building.")

    if not self.is_building:
        raise ValueError("Parent graph is not being built.")

    stream_handle = self._stream.handle
    _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return(
        driver.cuStreamGetCaptureInfo(stream_handle)
    )

    # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159
    # for rationale
    #
    # deps_info_out holds every output between the graph and the dependency
    # count. Its first entry is the dependency node list; a further entry is
    # presumably edge data on newer bindings (NOTE(review): confirm against
    # the binding in use). cuGraphAddChildGraphNode is called here with only
    # a node list and a count, so pass exactly the first entry. Slicing by
    # num_dependencies_out produced the wrong number of positional arguments
    # whenever the count was 0, or was 2+ with edge data present.
    dependencies = deps_info_out[0]
    child_node = handle_return(
        driver.cuGraphAddChildGraphNode(
            graph_out, dependencies, num_dependencies_out, as_py(child._h_graph)
        )
    )

    # The new child node becomes the sole capture dependency; any remaining
    # capture-info slots are passed as None.
    deps_info_update = [[child_node]] + [None] * (len(deps_info_out) - 1)
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            stream_handle,
            *deps_info_update,  # dependencies, edgeData
            1,
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

725  

def callback(self, fn, *, user_data=None) -> None:
    """Add a host callback to the graph during stream capture.

    The callback executes on the host CPU when graph execution reaches
    this point. Two kinds of ``fn`` are supported:

    - **Python callable**: any callable taking no arguments. The GIL is
      acquired automatically. Bind state with closures or
      ``functools.partial``.
    - **ctypes function pointer**: a ``ctypes.CFUNCTYPE`` instance that
      receives a single ``void*`` argument (the ``user_data``). The
      caller must keep the ctypes wrapper alive for the lifetime of the
      graph.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The callback function.
    user_data : int or bytes-like, optional
        Only for ctypes function pointers. If ``int``, passed as a raw
        pointer (caller manages lifetime). If bytes-like, the data is
        copied and its lifetime is tied to the graph.

    Raises
    ------
    RuntimeError
        If this builder has been closed or the stream is not actively
        capturing.
    """
    cdef Stream stream
    cdef cydriver.CUstream c_stream
    cdef cydriver.CUstreamCaptureStatus capture_status
    cdef cydriver.CUgraph c_graph = NULL
    cdef cydriver.CUhostFn c_fn
    cdef void* c_user_data = NULL

    # Must run before touching self._stream: a closed builder has reset it.
    GB_check_open(self)
    stream = self._stream
    c_stream = as_cu(stream._h_stream)

    with nogil:
        _get_capture_info(c_stream, &capture_status, &c_graph)
    if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    # Resolve fn/user_data into the C function pointer and argument to launch.
    _attach_host_callback_to_graph(c_graph, fn, user_data, &c_fn, &c_user_data)

    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data))

772  

773  

cdef inline int GB_check_open(GraphBuilder gb) except -1:
    """Raise if ``gb`` has been closed; otherwise return 0.

    A CLOSED builder has reset its stream and graph handles, so a method
    that dereferences them would read a null handle (or, for the cached
    Stream, a None typed as cdef Stream). Calling this first turns that
    into a clear RuntimeError.
    """
    if gb._state != CLOSED:
        return 0
    raise RuntimeError("Graph builder has been closed.")

785  

786  

cdef inline int GB_end_capture_if_needed(GraphBuilder gb, bint check_status) except -1 nogil:
    """End the in-progress capture when ``gb`` is the builder that owns it.

    Only a CAPTURING PRIMARY or CONDITIONAL_BODY builder owns the live
    capture. A FORKED builder must not call cuStreamEndCapture: the driver
    requires forked streams to be joined first.

    With check_status=True the driver return code is checked (close());
    with False it is ignored (__dealloc__).
    """
    cdef cydriver.CUstream c_stream
    cdef cydriver.CUgraph c_graph
    cdef cydriver.CUresult err

    # Nothing to end unless this builder is capturing and is not a fork.
    if gb._state != CAPTURING or gb._kind == FORKED:
        return 0

    if gb._h_stream:
        c_stream = as_cu(gb._h_stream)
        with nogil:
            err = cydriver.cuStreamEndCapture(c_stream, &c_graph)
        if check_status:
            HANDLE_RETURN(err)
    return 0

807  

808  

cdef inline GraphBuilder GB_init_forked(Stream stream, GraphHandle h_primary_graph):
    """Create a FORKED builder, already in the CAPTURING state, on ``stream``."""
    cdef GraphBuilder forked = GraphBuilder.__new__(GraphBuilder)
    forked._kind = FORKED
    forked._state = CAPTURING
    forked._stream = stream
    forked._h_stream = stream._h_stream
    # A FORKED builder captures into the primary's CUgraph. It holds the
    # primary's GraphHandle so conditional bodies created on it (via
    # GB_init_conditional -> create_graph_handle_ref(cond_graph, parent._h_graph))
    # have a valid parent handle to pin.
    forked._h_graph = h_primary_graph
    return forked

821  

822  

cdef inline GraphBuilder GB_init_conditional(Stream stream, cydriver.CUgraph cond_graph, GraphBuilder parent):
    """Create a CONDITIONAL_BODY builder for ``cond_graph`` on ``stream``.

    The builder starts in the CAPTURE_NOT_STARTED state. Its graph handle
    is created from ``cond_graph`` together with the parent's graph handle.
    """
    cdef GraphBuilder body = GraphBuilder.__new__(GraphBuilder)
    body._kind = CONDITIONAL_BODY
    body._state = CAPTURE_NOT_STARTED
    body._stream = stream
    body._h_stream = stream._h_stream
    body._h_graph = create_graph_handle_ref(cond_graph, parent._h_graph)
    return body

831  

832  

cdef inline int _get_capture_info(
    cydriver.CUstream stream,
    cydriver.CUstreamCaptureStatus* status,
    cydriver.CUgraph* graph) except?-1 nogil:
    """Thin wrapper around ``cuStreamGetCaptureInfo`` that papers over the
    CUDA 12 vs 13 signature change.

    ``status`` must be non-NULL: the driver rejects ``captureStatus_out=NULL``
    with ``CUDA_ERROR_INVALID_VALUE``. ``graph`` may be NULL when the caller
    does not need the graph handle.

    Only the capture status and graph are requested; every other output
    argument of the driver call is passed as NULL. The driver result goes
    through ``HANDLE_RETURN`` and its value is returned to the caller.
    """
    # Compile-time switch: the CUDA 13 build passes one more (NULL) output
    # argument than the CUDA 12 build.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
            stream, status, NULL, graph, NULL, NULL, NULL))
    ELSE:
        return HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
            stream, status, NULL, graph, NULL, NULL))

850  

851  

cdef inline tuple GB_cond_with_params(GraphBuilder gb, node_params):
    """Add a conditional node described by ``node_params`` and build its bodies.

    The node is added to the graph being captured on ``gb``'s stream with
    the stream's current capture dependencies, and then becomes the
    stream's only capture dependency. One CONDITIONAL_BODY builder, each on
    a newly created stream, is returned per body graph of the node.

    Raises RuntimeError when the stream is not actively capturing.
    """
    capture_info = handle_return(
        driver.cuStreamGetCaptureInfo(gb._stream.handle)
    )
    status, _, graph, *deps_info, num_dependencies = capture_info
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    cond_node = handle_return(
        driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params)
    )

    # The new node replaces the dependency list; any remaining capture-info
    # slots are passed as None.
    deps_info_update = [[cond_node]]
    deps_info_update.extend([None] * (len(deps_info) - 1))

    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            gb._stream.handle,
            *deps_info_update,  # dependencies, edgeData
            1,  # numDependencies
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

    body_builders = []
    for i in range(node_params.conditional.size):
        body_stream = gb._stream.device.create_stream()
        body_graph_addr = int(node_params.conditional.phGraph_out[i])
        body_builders.append(
            GB_init_conditional(
                body_stream,
                <cydriver.CUgraph><intptr_t>body_graph_addr,
                gb,
            )
        )
    return tuple(body_builders)

880  

881  

882cdef class Graph: 

883 """An executable graph. 

884  

885 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

886 them with a specified dependency tree. It speeds up the workflow by combining the 

887 driver activities associated with CUDA kernel launches and CUDA API calls. 

888  

889 Graphs must be built using a :obj:`~graph.GraphBuilder` object. 

890  

891 """ 

892  

def __init__(self):
    # Instances are created through the cdef factory ``Graph._init``.
    message = "directly constructing a Graph instance is not supported"
    raise RuntimeError(message)

895  

@staticmethod
cdef Graph _init(cydriver.CUgraphExec graph_exec):
    """Create a Graph wrapping ``graph_exec`` without going through ``__init__``."""
    cdef Graph graph = Graph.__new__(Graph)
    graph._h_graph_exec = create_graph_exec_handle(graph_exec)
    return graph

901  

def close(self) -> None:
    """Destroy the graph.

    Resets this object's executable-graph handle.
    """
    self._h_graph_exec.reset()

905  

@property
def handle(self) -> driver.CUgraphExec:
    """The underlying ``CUgraphExec`` object.

    .. caution::

        The returned handle is a Python object. Call ``int()`` on it to
        obtain the memory address of the underlying C handle.

    """
    py_handle = as_py(self._h_graph_exec)
    return py_handle

917  

def update(self, source: "GraphBuilder | GraphDefinition") -> None:
    """Update this executable graph from a new graph definition.

    The topology of ``source`` must be identical to this graph.

    Parameters
    ----------
    source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDefinition`
        The graph definition to update from. A GraphBuilder must have
        finished building.

    Raises
    ------
    ValueError
        If ``source`` is a GraphBuilder that has been closed or has not
        finished building.
    TypeError
        If ``source`` is neither a GraphBuilder nor a GraphDefinition.
    CUDAError
        If the driver reports that the update failed.

    """
    # Imported here rather than at module level (the module-level import
    # is only done under TYPE_CHECKING).
    from cuda.core.graph import GraphDefinition

    cdef cydriver.CUgraphExec cu_exec = as_cu(self._h_graph_exec)
    cdef cydriver.CUgraph cu_graph
    cdef cydriver.CUgraphExecUpdateResultInfo result_info
    cdef cydriver.CUresult err
    cdef GraphBuilder builder

    if isinstance(source, GraphBuilder):
        builder = <GraphBuilder>source
        if builder._state == CLOSED:
            raise ValueError("Source graph builder has been closed.")
        if builder._state != CAPTURE_ENDED:
            raise ValueError("Graph has not finished building.")
        cu_graph = as_cu(builder._h_graph)
    elif isinstance(source, GraphDefinition):
        cu_graph = <cydriver.CUgraph><intptr_t>int(source.handle)
    else:
        raise TypeError(
            f"expected GraphBuilder or GraphDefinition, got {type(source).__name__}")

    with nogil:
        err = cydriver.cuGraphExecUpdate(cu_exec, cu_graph, &result_info)

    if err != cydriver.CUresult.CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE:
        # Success, or some other driver error handled generically.
        HANDLE_RETURN(err)
        return

    # Update-specific failure: report the reason from the result info.
    reason = driver.CUgraphExecUpdateResult(result_info.result)
    raise CUDAError(f"Graph update failed: {reason.__doc__.strip()} ({reason.name})")

956  

def upload(self, stream: Stream) -> None:
    """Upload the graph in a stream.

    Parameters
    ----------
    stream : :obj:`~_stream.Stream`
        The stream in which to upload the graph.

    """
    cdef cydriver.CUgraphExec c_exec
    cdef cydriver.CUstream c_stream

    c_exec = as_cu(self._h_graph_exec)
    # Convert the Python-level stream handle to the raw C stream pointer.
    c_stream = <cydriver.CUstream><intptr_t>int(stream.handle)
    with nogil:
        HANDLE_RETURN(cydriver.cuGraphUpload(c_exec, c_stream))

970  

def launch(self, stream: Stream) -> None:
    """Launch the graph in a stream.

    Parameters
    ----------
    stream : :obj:`~_stream.Stream`
        The stream in which to launch the graph.

    """
    cdef cydriver.CUgraphExec c_exec
    cdef cydriver.CUstream c_stream

    c_exec = as_cu(self._h_graph_exec)
    # Convert the Python-level stream handle to the raw C stream pointer.
    c_stream = <cydriver.CUstream><intptr_t>int(stream.handle)
    with nogil:
        HANDLE_RETURN(cydriver.cuGraphLaunch(c_exec, c_stream))