Coverage for cuda / core / _graph / __init__.py: 89.39%

330 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-25 01:07 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7import weakref 

8from dataclasses import dataclass 

9from typing import TYPE_CHECKING 

10 

11if TYPE_CHECKING: 

12 from cuda.core._stream import Stream 

13 

14from cuda.core._utils.cuda_utils import ( 

15 driver, 

16 get_binding_version, 

17 handle_return, 

18) 

19 

# Module-level state populated lazily by `_lazy_init()`.
_inited = False
_py_major_minor = None  # cuda-python binding version tuple, compared against e.g. (12, 8)
_driver_ver = None  # CUDA driver version integer, compared against e.g. 12080

22 

23 

def _lazy_init():
    """Fill in the cached binding/driver versions the first time it is called.

    Later calls return immediately without touching the driver.
    """
    global _inited, _py_major_minor, _driver_ver
    if not _inited:
        # Which driver bindings are available depends on the cuda-python version.
        _py_major_minor = get_binding_version()
        _driver_ver = handle_return(driver.cuDriverGetVersion())
        _inited = True

34 

35 

@dataclass
class GraphDebugPrintOptions:
    """Options controlling the output of :obj:`_graph.GraphBuilder.debug_dot_print()`.

    Every option is a boolean that defaults to False. Each enabled option adds
    the matching ``CU_GRAPH_DEBUG_DOT_FLAGS_*`` driver flag.

    Attributes
    ----------
    verbose : bool
        Output all debug data, as if every debug flag were enabled.
    runtime_types : bool
        Use CUDA Runtime structures for the output.
    kernel_node_params : bool
        Add kernel parameter values to the output.
    memcpy_node_params : bool
        Add memcpy parameter values to the output.
    memset_node_params : bool
        Add memset parameter values to the output.
    host_node_params : bool
        Add host parameter values to the output.
    event_node_params : bool
        Add event parameter values to the output.
    ext_semas_signal_node_params : bool
        Add external semaphore signal parameter values to the output.
    ext_semas_wait_node_params : bool
        Add external semaphore wait parameter values to the output.
    kernel_node_attributes : bool
        Add kernel node attributes to the output.
    handles : bool
        Add node handles and every kernel function handle to the output.
    mem_alloc_node_params : bool
        Add memory alloc parameter values to the output.
    mem_free_node_params : bool
        Add memory free parameter values to the output.
    batch_mem_op_node_params : bool
        Add batch mem op parameter values to the output.
    extra_topo_info : bool
        Add edge numbering information.
    conditional_node_params : bool
        Add conditional node parameter values to the output.
    """

    verbose: bool = False
    runtime_types: bool = False
    kernel_node_params: bool = False
    memcpy_node_params: bool = False
    memset_node_params: bool = False
    host_node_params: bool = False
    event_node_params: bool = False
    ext_semas_signal_node_params: bool = False
    ext_semas_wait_node_params: bool = False
    kernel_node_attributes: bool = False
    handles: bool = False
    mem_alloc_node_params: bool = False
    mem_free_node_params: bool = False
    batch_mem_op_node_params: bool = False
    extra_topo_info: bool = False
    conditional_node_params: bool = False

    def _to_flags(self) -> int:
        """Return the driver flag mask for the enabled options (internal use)."""
        # Option ``foo_bar`` maps to driver flag ``CU_GRAPH_DEBUG_DOT_FLAGS_FOO_BAR``.
        # Names are listed in field declaration order.
        option_names = (
            "verbose",
            "runtime_types",
            "kernel_node_params",
            "memcpy_node_params",
            "memset_node_params",
            "host_node_params",
            "event_node_params",
            "ext_semas_signal_node_params",
            "ext_semas_wait_node_params",
            "kernel_node_attributes",
            "handles",
            "mem_alloc_node_params",
            "mem_free_node_params",
            "batch_mem_op_node_params",
            "extra_topo_info",
            "conditional_node_params",
        )
        mask = 0
        for name in option_names:
            if getattr(self, name):
                # The driver enum is only consulted for options that are actually enabled.
                mask |= getattr(driver.CUgraphDebugDot_flags, "CU_GRAPH_DEBUG_DOT_FLAGS_" + name.upper())
        return mask

130 

131 

@dataclass
class GraphCompleteOptions:
    """Options accepted by :obj:`_graph.GraphBuilder.complete()`.

    Attributes
    ----------
    auto_free_on_launch : bool, optional
        Automatically free memory allocated in a graph before relaunching it.
        Defaults to False.
    upload_stream : Stream, optional
        Stream used to upload the graph automatically once it is completed.
        Defaults to None.
    device_launch : bool, optional
        Make the graph launchable from the device. Only usable on platforms
        that support unified addressing, and cannot be combined with
        ``auto_free_on_launch``. Defaults to False.
    use_node_priority : bool, optional
        Run the graph using the per-node priority attributes instead of the
        priority of the stream it is launched into. Defaults to False.
    """

    # Field order defines the positional constructor signature; keep it stable.
    auto_free_on_launch: bool = False
    upload_stream: Stream | None = None
    device_launch: bool = False
    use_node_priority: bool = False

156 

157 

def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate ``h_graph`` and wrap the result in a :obj:`~_graph.Graph`.

    Parameters
    ----------
    h_graph : driver.CUgraph
        The graph to instantiate.
    options : :obj:`~_graph.GraphCompleteOptions`, optional
        Instantiation options. When omitted, no instantiation flags are set.

    Returns
    -------
    graph : :obj:`~_graph.Graph`
        The instantiated graph.

    Raises
    ------
    RuntimeError
        If ``result_out`` of the instantiation parameters reports a failure.
    """
    # `_py_major_minor` (read below) is populated by `_lazy_init()`. Call it here so this
    # function also works when no GraphBuilder has been created yet.
    _lazy_init()

    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options is not None:
        inst_flags = driver.CUgraphInstantiate_flags
        flags = 0
        if options.auto_free_on_launch:
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            flags |= inst_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = flags

    # NOTE(review): if the driver also reports a failed instantiation through its return
    # code, handle_return raises here before the more specific result_out messages below
    # are reached. Confirm whether that is intended.
    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))

    result = params.result_out
    results = driver.CUgraphInstantiateResult
    if result == results.CUDA_GRAPH_INSTANTIATE_ERROR:
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    elif result == results.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    elif result == results.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    elif result == results.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    elif (
        # This enum member only exists in bindings >= 12.8, so the version test must come first.
        _py_major_minor >= (12, 8)
        and result == results.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    elif result != results.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {result}")
    return graph

194 

195 

196class GraphBuilder: 

197 """Represents a graph under construction. 

198 

199 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

200 them with a specified dependency tree. It speeds up the workflow by combining the 

201 driver activities associated with CUDA kernel launches and CUDA API calls. 

202 

203 Directly creating a :obj:`~_graph.GraphBuilder` is not supported due 

204 to ambiguity. New graph builders should instead be created through a 

205 :obj:`~_device.Device`, or a :obj:`~_stream.stream` object. 

206 

207 """ 

208 

class _MembersNeededForFinalize:
    # State owned by a GraphBuilder that must be released by its weakref finalizer.
    __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")

    def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
        self.stream = stream_obj
        self.is_stream_owner = is_stream_owner
        self.graph = None
        self.conditional_graph = conditional_graph
        self.is_join_required = is_join_required
        # The finalizer keeps this object alive through the bound method, not the
        # builder itself, so the builder can still be garbage collected.
        weakref.finalize(graph_builder_obj, self.close)

    def _end_pending_capture(self):
        """End a capture still active on ``self.stream`` and discard the graph it produced."""
        capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0]
        if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
            return
        # This only happens for builders that do not require joining. Ending the capture
        # of a stream that was split off from the primary builder would fail with
        # CUDA_ERROR_STREAM_CAPTURE_UNJOINED. Users must therefore join all split graph
        # builders before a graph builder can be cleanly destroyed.
        abandoned_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
        if not self.conditional_graph:
            # A regular capture yields a brand-new graph that nothing else references;
            # destroy it instead of leaking it. A conditional builder's graph is a body
            # of its conditional node (taken from phGraph_out), so it is left alone.
            handle_return(driver.cuGraphDestroy(abandoned_graph))

    def close(self):
        """Release the stream (closing it if owned) and the captured graph.

        Safe to call more than once; later calls are no-ops.
        """
        if self.stream:
            try:
                if not self.is_join_required:
                    self._end_pending_capture()
            finally:
                # Release the stream even if ending the capture failed.
                if self.is_stream_owner:
                    self.stream.close()
                self.stream = None
        if self.graph:
            handle_return(driver.cuGraphDestroy(self.graph))
            self.graph = None
        self.conditional_graph = None

__slots__ = ("__weakref__", "_building_ended", "_mnff")

240 

def __init__(self):
    # Direct construction is rejected. Builders are created via the factory methods named
    # in the message below, or internally through ``GraphBuilder._init``.
    raise NotImplementedError(
        "directly creating a GraphBuilder object can be ambiguous. Please either "
        "call Device.create_graph_builder() or stream.create_graph_builder()"
    )

246 

@classmethod
def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
    # Internal constructor. ``__init__`` always raises, so the instance is allocated with
    # ``__new__`` and populated here.
    _lazy_init()
    builder = cls.__new__(cls)
    builder._mnff = GraphBuilder._MembersNeededForFinalize(
        builder, stream, is_stream_owner, conditional_graph, is_join_required
    )
    builder._building_ended = False
    return builder

257 

@property
def stream(self) -> Stream:
    """The stream this graph builder captures work on."""
    members = self._mnff
    return members.stream

262 

@property
def is_join_required(self) -> bool:
    """Whether this builder must be joined before building is ended.

    This is True for the additional builders produced by :meth:`split`.
    """
    members = self._mnff
    return members.is_join_required

267 

def begin_building(self, mode="relaxed") -> GraphBuilder:
    """Begins the building process.

    Build `mode` for controlling interaction with other API calls must be one of the following:

    - `global` : Prohibit potentially unsafe operations across all streams in the process.
    - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
    - `relaxed` : The local thread is not prohibited from potentially unsafe operations.

    Parameters
    ----------
    mode : str, optional
        Build mode to control the interaction with other API calls that are potentially unsafe.
        Defaults to relaxed.

    Returns
    -------
    graph_builder : :obj:`~_graph.GraphBuilder`
        This builder, to allow call chaining.

    Raises
    ------
    RuntimeError
        If building has already ended for this builder.
    ValueError
        If `mode` is not one of the supported build modes.

    """
    if self._building_ended:
        raise RuntimeError("Cannot resume building after building has ended.")

    # Single source of truth for supported modes. The driver enum is only consulted after
    # validation, so an invalid mode is reported before touching the driver.
    capture_mode_names = {
        "global": "CU_STREAM_CAPTURE_MODE_GLOBAL",
        "thread_local": "CU_STREAM_CAPTURE_MODE_THREAD_LOCAL",
        "relaxed": "CU_STREAM_CAPTURE_MODE_RELAXED",
    }
    try:
        mode_name = capture_mode_names[mode]
    except (KeyError, TypeError):  # TypeError: unhashable `mode`
        raise ValueError(f"Unsupported build mode: {mode}") from None
    capture_mode = getattr(driver.CUstreamCaptureMode, mode_name)

    stream_handle = self._mnff.stream.handle
    if self._mnff.conditional_graph:
        # Conditional builders capture into the body graph of their conditional node.
        handle_return(
            driver.cuStreamBeginCaptureToGraph(
                stream_handle,
                self._mnff.conditional_graph,
                None,  # dependencies
                None,  # dependencyData
                0,  # numDependencies
                capture_mode,
            )
        )
    else:
        handle_return(driver.cuStreamBeginCapture(stream_handle, capture_mode))
    return self

311 

@property
def is_building(self) -> bool:
    """Whether stream capture is currently active for this graph builder.

    Raises RuntimeError if the capture has been invalidated, and
    NotImplementedError for any other unrecognized capture status.
    """
    statuses = driver.CUstreamCaptureStatus
    status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0]
    if status == statuses.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        return True
    if status == statuses.CU_STREAM_CAPTURE_STATUS_NONE:
        return False
    if status == statuses.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. Build process must now be ended."
        )
    raise NotImplementedError(f"Unsupported capture status type received: {status}")

326 

def end_building(self) -> GraphBuilder:
    """Ends the building process.

    Returns
    -------
    graph_builder : :obj:`~_graph.GraphBuilder`
        This builder, to allow call chaining.

    Raises
    ------
    RuntimeError
        If the builder is not building, or if the build was invalidated. In the
        latter case the invalidated capture is still ended (so the stream stops
        capturing), but no graph is stored.
    NotImplementedError
        If the driver reports an unrecognized capture status.
    """
    stream_handle = self._mnff.stream.handle
    statuses = driver.CUstreamCaptureStatus
    capture_status = handle_return(driver.cuStreamGetCaptureInfo(stream_handle))[0]
    if capture_status == statuses.CU_STREAM_CAPTURE_STATUS_NONE:
        raise RuntimeError("Graph builder is not building.")
    if capture_status == statuses.CU_STREAM_CAPTURE_STATUS_INVALIDATED:
        # `is_building` tells users an invalidated build "must now be ended". Going through
        # it here would re-raise that same error and make ending impossible, so the capture
        # is ended directly. The driver may report the invalidation as an error from this
        # call, so its result is intentionally not passed to handle_return.
        driver.cuStreamEndCapture(stream_handle)
        raise RuntimeError(
            "Build process encountered an error and has been invalidated. "
            "The capture has been ended and no graph was produced."
        )
    if capture_status != statuses.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")

    captured = handle_return(driver.cuStreamEndCapture(stream_handle))
    if self._mnff.conditional_graph:
        self._mnff.conditional_graph = captured
    else:
        self._mnff.graph = captured

    # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
    # resume the build process after the first call to end_building()
    self._building_ended = True
    return self

340 

def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate the finished graph and return it as a :obj:`~_graph.Graph`.

    Parameters
    ----------
    options : :obj:`~_graph.GraphCompleteOptions`, optional
        Dataclass of options applied when completing the graph builder.

    Returns
    -------
    graph : :obj:`~_graph.Graph`
        The newly built graph.

    Raises
    ------
    RuntimeError
        If :meth:`end_building` has not been called yet.

    """
    if self._building_ended:
        return _instantiate_graph(self._mnff.graph, options)
    raise RuntimeError("Graph has not finished building.")

359 

def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
    """Generates a DOT debug file for the graph builder.

    Parameters
    ----------
    path : str, bytes or os.PathLike
        File path to use for writing the debug DOT output.
    options : :obj:`~_graph.GraphDebugPrintOptions`, optional
        Customizable dataclass for the debug print options.

    Raises
    ------
    RuntimeError
        If building has not ended yet.

    """
    import os

    if not self._building_ended:
        raise RuntimeError("Graph has not finished building.")
    flags = options._to_flags() if options is not None else 0
    # The driver binding takes the path as a C string. fsencode leaves bytes unchanged and
    # encodes str / os.PathLike values with the filesystem encoding.
    handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, os.fsencode(path), flags))

375 

def split(self, count: int) -> tuple[GraphBuilder, ...]:
    """Splits the original graph builder into multiple graph builders.

    The new builders inherit work dependencies from the original builder.
    The original builder is reused for the split and is returned first in the tuple.

    Parameters
    ----------
    count : int
        The number of graph builders to split the graph builder into.

    Returns
    -------
    graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
        A tuple of split graph builders. The first graph builder in the tuple
        is always the original graph builder.

    Raises
    ------
    ValueError
        If `count` is less than 2.

    """
    if count < 2:
        raise ValueError(f"Invalid split count: expecting >= 2, got {count}")

    source_stream = self._mnff.stream
    # Each new stream waits on this event, so it picks up the work captured so far.
    event = source_stream.record()
    builders = [self]
    try:
        for _ in range(count - 1):
            stream = source_stream.device.create_stream()
            stream.wait(event)
            builders.append(
                GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True)
            )
    finally:
        # Close the event even if creating one of the streams fails.
        event.close()
    return tuple(builders)

407 

@staticmethod
def join(*graph_builders) -> GraphBuilder:
    """Merge several graph builders back into a single graph builder.

    The returned builder inherits work dependencies from every provided builder.
    All builders other than the returned one are closed.

    Parameters
    ----------
    *graph_builders : :obj:`~_graph.GraphBuilder`
        The graph builders to join.

    Returns
    -------
    graph_builder : :obj:`~_graph.GraphBuilder`
        The newly joined graph builder.

    Raises
    ------
    TypeError
        If any argument is not a GraphBuilder.
    ValueError
        If fewer than two builders are given.

    """
    if not all(isinstance(candidate, GraphBuilder) for candidate in graph_builders):
        raise TypeError("All arguments must be GraphBuilder instances")
    if len(graph_builders) < 2:
        raise ValueError("Must join with at least two graph builders")

    # The root is the first builder that does not itself require joining; fall back to
    # the first builder when every one of them does.
    root_idx = next(
        (idx for idx, candidate in enumerate(graph_builders) if not candidate.is_join_required),
        0,
    )
    root = graph_builders[root_idx]
    for idx, other in enumerate(graph_builders):
        if idx != root_idx:
            root.stream.wait(other.stream)
            other.close()
    return root

446 

def __cuda_stream__(self) -> tuple[int, int]:
    """Implement the ``__cuda_stream__`` protocol by delegating to the builder's stream."""
    underlying = self.stream
    return underlying.__cuda_stream__()

450 

def _get_conditional_context(self) -> driver.CUcontext:
    # Conditional handles and nodes are created against the context of the builder's stream.
    stream_context = self._mnff.stream.context
    return stream_context.handle

453 

def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle:
    """Create a conditional handle on the graph currently being captured.

    Parameters
    ----------
    default_value : int, optional
        The default value to assign to the conditional handle. When given, the
        handle is created with ``CU_GRAPH_COND_ASSIGN_DEFAULT``. Otherwise it is
        created with flags ``0`` and value ``0``.

    Returns
    -------
    handle : driver.CUgraphConditionalHandle
        The newly created conditional handle.

    Raises
    ------
    RuntimeError
        If the driver or binding is older than 12.3, or the builder is not capturing.

    """
    if _driver_ver < 12030:
        raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles")
    if _py_major_minor < (12, 3):
        raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles")

    if default_value is None:
        default_value, flags = 0, 0
    else:
        flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT

    status, _capture_id, graph, *_deps, _num_deps = handle_return(
        driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
    )
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot create a conditional handle when graph is not being built")

    context = self._get_conditional_context()
    return handle_return(driver.cuGraphConditionalHandleCreate(graph, context, default_value, flags))

485 

def _cond_with_params(self, node_params) -> tuple[GraphBuilder, ...]:
    """Add a conditional node to the graph being captured and return builders for its bodies.

    The node is added on top of the stream's current capture dependencies. The
    stream's capture dependencies are then set to just the new node, so work
    captured on this builder afterwards depends on the conditional node.

    Parameters
    ----------
    node_params : driver.CUgraphNodeParams
        Conditional node parameters, populated by the callers in this class with
        type, handle, conditional type, size and ctx.

    Returns
    -------
    tuple[GraphBuilder, ...]
        ``node_params.conditional.size`` new builders. Each owns a freshly created
        stream and targets the matching graph in ``node_params.conditional.phGraph_out``.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    """
    # Get current capture info to ensure we're in a valid state
    # `deps_info` gathers everything between the graph and the dependency count and is
    # forwarded positionally to cuGraphAddNode. Its length presumably depends on the
    # binding version (dependencies only, or dependencies plus edge data) — TODO confirm.
    status, _, graph, *deps_info, num_dependencies = handle_return(
        driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)
    )
    if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add conditional node when not actively capturing")

    # Add the conditional node to the graph
    # First entry: a one-element list holding the new node (the new dependency set).
    # Any remaining entries (matching the extra items in `deps_info`) are passed as None.
    deps_info_update = [
        [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))]
    ] + [None] * (len(deps_info) - 1)

    # Update the stream's capture dependencies
    handle_return(
        driver.cuStreamUpdateCaptureDependencies(
            self._mnff.stream.handle,
            *deps_info_update,  # dependencies, edgeData
            1,  # numDependencies
            driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES,
        )
    )

    # Create new graph builders for each condition
    return tuple(
        [
            GraphBuilder._init(
                stream=self._mnff.stream.device.create_stream(),
                is_stream_owner=True,
                conditional_graph=node_params.conditional.phGraph_out[i],
                is_join_required=False,
            )
            for i in range(node_params.conditional.size)
        ]
    )

521 

def if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
    """Add an "if" conditional node and return a graph builder for its body.

    The body only executes when the conditional handle evaluates to true at
    runtime. The new builder inherits work dependencies from this builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the if conditional.

    Returns
    -------
    graph_builder : :obj:`~_graph.GraphBuilder`
        The newly created conditional graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding is older than 12.3.

    """
    if _driver_ver < 12030:
        raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if")
    if _py_major_minor < (12, 3):
        raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if")

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond = params.conditional
    cond.handle = handle
    cond.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    cond.size = 1  # one body: executed when the handle is true
    cond.ctx = self._get_conditional_context()
    branches = self._cond_with_params(params)
    return branches[0]

552 

def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]:
    """Add an "if/else" conditional node and return graph builders for both bodies.

    The first body executes when the conditional handle evaluates to true at
    runtime; otherwise the second ("else") body executes. The new builders
    inherit work dependencies from this builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the if-else conditional.

    Returns
    -------
    graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`]
        The builder for the if branch followed by the builder for the else branch.

    Raises
    ------
    RuntimeError
        If the driver or binding is older than 12.8.

    """
    if _driver_ver < 12080:
        raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else")
    if _py_major_minor < (12, 8):
        raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else")

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond = params.conditional
    cond.handle = handle
    cond.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF
    cond.size = 2  # an IF node with two bodies: "then" and "else"
    cond.ctx = self._get_conditional_context()
    return self._cond_with_params(params)

583 

def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]:
    """Adds a switch condition branch and returns new graph builders for all cases.

    The resulting switch graph will execute the branch that matches the
    case index of the conditional handle at runtime. If no match is found, no branch
    will be executed.

    The new builders inherit work dependencies from the original builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the switch conditional.
    count : int
        The number of cases to add to the switch conditional. Must be at least 1.

    Returns
    -------
    graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
        A tuple of new graph builders, one for each branch.

    Raises
    ------
    ValueError
        If `count` is less than 1.
    RuntimeError
        If the driver or binding is older than 12.8.

    """
    # Reject an empty switch up front (mirroring `split`'s validation) instead of letting
    # the driver fail after the capture state has already been queried.
    if count < 1:
        raise ValueError(f"Invalid switch count: expecting >= 1, got {count}")
    if _driver_ver < 12080:
        raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch")
    if _py_major_minor < (12, 8):
        raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch")
    node_params = driver.CUgraphNodeParams()
    node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    node_params.conditional.handle = handle
    node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH
    node_params.conditional.size = count
    node_params.conditional.ctx = self._get_conditional_context()
    return self._cond_with_params(node_params)

617 

def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder:
    """Add a "while" conditional node and return a graph builder for its body.

    At runtime the body is executed repeatedly until the conditional handle
    evaluates to false. The new builder inherits work dependencies from this
    builder.

    Parameters
    ----------
    handle : driver.CUgraphConditionalHandle
        The handle to use for the while loop.

    Returns
    -------
    graph_builder : :obj:`~_graph.GraphBuilder`
        The newly created while loop graph builder.

    Raises
    ------
    RuntimeError
        If the driver or binding is older than 12.3.

    """
    if _driver_ver < 12030:
        raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop")
    if _py_major_minor < (12, 3):
        raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop")

    params = driver.CUgraphNodeParams()
    params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    cond = params.conditional
    cond.handle = handle
    cond.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    cond.size = 1  # a single loop body
    cond.ctx = self._get_conditional_context()
    bodies = self._cond_with_params(params)
    return bodies[0]

648 

649 def close(self): 

650 """Destroy the graph builder. 

651 

652 Closes the associated stream if we own it. Borrowed stream 

653 object will instead have their references released. 

654 

655 """ 

656 self._mnff.close() 1Jvxbcdefghijklmnopqrstua

657 

658 def add_child(self, child_graph: GraphBuilder): 

659 """Adds the child :obj:`~_graph.GraphBuilder` builder into self. 

660 

661 The child graph builder will be added as a child node to the parent graph builder. 

662 

663 Parameters 

664 ---------- 

665 child_graph : :obj:`~_graph.GraphBuilder` 

666 The child graph builder. Must have finished building. 

667 """ 

668 if (_driver_ver < 12000) or (_py_major_minor < (12, 0)): 1E

669 raise NotImplementedError( 

670 f"Launching child graphs is not implemented for versions older than CUDA 12." 

671 f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}" 

672 ) 

673 

674 if not child_graph._building_ended: 1E

675 raise ValueError("Child graph has not finished building.") 

676 

677 if not self.is_building: 1E

678 raise ValueError("Parent graph is not being built.") 

679 

680 stream_handle = self._mnff.stream.handle 1E

681 _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return( 1E

682 driver.cuStreamGetCaptureInfo(stream_handle) 

683 ) 

684 

685 # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159 

686 # for rationale 

687 deps_info_trimmed = deps_info_out[:num_dependencies_out] 1E

688 deps_info_update = [ 1E

689 [ 

690 handle_return( 

691 driver.cuGraphAddChildGraphNode( 

692 graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph 

693 ) 

694 ) 

695 ] 

696 ] + [None] * (len(deps_info_out) - 1) 

697 handle_return( 1E

698 driver.cuStreamUpdateCaptureDependencies( 

699 stream_handle, 

700 *deps_info_update, # dependencies, edgeData 

701 1, 

702 driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, 

703 ) 

704 ) 

705 

706 

707class Graph: 

708 """Represents an executable graph. 

709 

710 A graph groups a set of CUDA kernels and other CUDA operations together and executes 

711 them with a specified dependency tree. It speeds up the workflow by combining the 

712 driver activities associated with CUDA kernel launches and CUDA API calls. 

713 

714 Graphs must be built using a :obj:`~_graph.GraphBuilder` object. 

715 

716 """ 

717 

718 class _MembersNeededForFinalize: 

719 __slots__ = "graph" 

720 

721 def __init__(self, graph_obj, graph): 

722 self.graph = graph 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdbebfb_ ` | } ~ { D

723 weakref.finalize(graph_obj, self.close) 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdbebfb_ ` | } ~ { D

724 

725 def close(self): 

726 if self.graph: 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdb_ ` | } ~ { D gb

727 handle_return(driver.cuGraphExecDestroy(self.graph)) 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdb_ ` | } ~ { D gb

728 self.graph = None 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdb_ ` | } ~ { D gb

729 

730 __slots__ = ("__weakref__", "_mnff") 

731 

732 def __init__(self): 

733 raise RuntimeError("directly constructing a Graph instance is not supported") 

734 

735 @classmethod 

736 def _init(cls, graph): 

737 self = cls.__new__(cls) 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdbebfb_ ` | } ~ { D

738 self._mnff = Graph._MembersNeededForFinalize(self, graph) 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdbebfb_ ` | } ~ { D

739 return self 2E J w I v x P K 2 U 0 3 V 1 W L Q X M R H F G Y N S Z O T b c d e f g h i j k l m n o p q r s t u y z A B / 6 : 7 8 ; = ? @ 9 [ ! # $ % ] ' ( ) ^ * + , abbb- cbdbebfb_ ` | } ~ { D

740 

741 def close(self): 

742 """Destroy the graph.""" 

743 self._mnff.close() 1JD

744 

745 @property 

746 def handle(self) -> driver.CUgraphExec: 

747 """Return the underlying ``CUgraphExec`` object. 

748 

749 .. caution:: 

750 

751 This handle is a Python object. To get the memory address of the underlying C 

752 handle, call ``int()`` on the returned object. 

753 

754 """ 

755 return self._mnff.graph 

756 

757 def update(self, builder: GraphBuilder): 

758 """Update the graph using new build configuration from the builder. 

759 

760 The topology of the provided builder must be identical to this graph. 

761 

762 Parameters 

763 ---------- 

764 builder : :obj:`~_graph.GraphBuilder` 

765 The builder to update the graph with. 

766 

767 """ 

768 if not builder._building_ended: 1w

769 raise ValueError("Graph has not finished building.") 

770 

771 # Update the graph with the new nodes from the builder 

772 exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph)) 1w

773 if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS: 1w

774 raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}") 

775 

776 def upload(self, stream: Stream): 

777 """Uploads the graph in a stream. 

778 

779 Parameters 

780 ---------- 

781 stream : :obj:`~_stream.Stream` 

782 The stream in which to upload the graph 

783 

784 """ 

785 handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle)) 1vKWLQXMRHFGYNSZOT/6:7;=?@9[!$%]')^*_`{

786 

787 def launch(self, stream: Stream): 

788 """Launches the graph in a stream. 

789 

790 Parameters 

791 ---------- 

792 stream : :obj:`~_stream.Stream` 

793 The stream in which to launch the graph 

794 

795 """ 

796 handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle)) 1EwvPKWLQXMRHFGYNSZOTbcdefghijklmnopqrstuyzAB/6:78;=?@9[!#$%]'()^*+_`|}~{