Coverage for cuda / core / graph / _adjacency_set_proxy.pyx: 86.99%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-29 01:27 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5"""Mutable-set proxy for graph node predecessors and successors.""" 

6  

7from libc.stddef cimport size_t 

8from libcpp.vector cimport vector 

9from cuda.bindings cimport cydriver 

10from cuda.core.graph._graph_node cimport GraphNode 

11from cuda.core._resource_handles cimport ( 

12 GraphHandle, 

13 GraphNodeHandle, 

14 as_cu, 

15 graph_node_get_graph, 

16) 

17from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

18from collections.abc import MutableSet 

19  

20  

21# ---- Python MutableSet wrapper ---------------------------------------------- 

22  

cdef inline size_t _node_key(GraphNode node) noexcept:
    # Raw CUgraphNode pointer of *node* as an integer key. Distinct GraphNode
    # wrappers around the same driver node share a key, so the proxy can
    # de-duplicate nodes and test membership against a single snapshot
    # instead of issuing one driver query per node.
    return <size_t><void*>as_cu(node._h_node)


class AdjacencySetProxy(MutableSet):
    """Mutable set proxy for a node's predecessors or successors.

    Reads (membership, iteration, length) query the CUDA graph on every call
    and mutations write through to it; the proxy stores no membership itself.
    """

    __slots__ = ("_core",)

    def __init__(self, node, bint is_fwd):
        self._core = _AdjacencySetCore(node, is_fwd)

    # Used by operators such as &|^ to create non-proxy views when needed.
    @classmethod
    def _from_iterable(cls, it):
        return set(it)

    # --- abstract methods required by MutableSet ---

    def __contains__(self, x):
        if not isinstance(x, GraphNode):
            return False
        return (<_AdjacencySetCore>self._core).contains(<GraphNode>x)

    def __iter__(self):
        # Iterates over a snapshot list, so mutating the proxy while
        # iterating does not disturb the iteration.
        return iter((<_AdjacencySetCore>self._core).query())

    def __len__(self):
        return (<_AdjacencySetCore>self._core).count()

    def add(self, value):
        """Add an edge to *value*; a no-op if it is already adjacent.

        Raises:
            TypeError: if *value* is not a GraphNode.
        """
        if not isinstance(value, GraphNode):
            raise TypeError(
                f"expected GraphNode, got {type(value).__name__}")
        if value in self:
            return
        (<_AdjacencySetCore>self._core).add_edge(<GraphNode>value)

    def discard(self, value):
        """Remove the edge to *value* if present; otherwise do nothing."""
        if not isinstance(value, GraphNode):
            return
        if value not in self:
            return
        (<_AdjacencySetCore>self._core).remove_edge(<GraphNode>value)

    # --- internal helpers ---

    def _member_keys(self):
        """Return the raw-handle keys of the current members as a set.

        Costs one driver query, however many lookups are made against the
        result afterwards.
        """
        return {_node_key(<GraphNode>m)
                for m in (<_AdjacencySetCore>self._core).query()}

    # --- override for bulk efficiency ---

    def clear(self):
        """Remove all edges in a single driver call."""
        members = (<_AdjacencySetCore>self._core).query()
        if members:
            (<_AdjacencySetCore>self._core).remove_edges(members)

    def __isub__(self, it):
        """Remove edges to all nodes in *it* in a single driver call.

        Non-GraphNode items and nodes that are not currently adjacent are
        ignored. A node listed more than once is removed only once, because
        the driver rejects removing a dependency that no longer exists.
        """
        if it is self:
            self.clear()
            return self
        members = self._member_keys()
        to_remove = []
        for v in it:
            if not isinstance(v, GraphNode):
                continue
            key = _node_key(<GraphNode>v)
            if key in members:
                # Drop the key so a repeated node is not queued twice.
                members.discard(key)
                to_remove.append(v)
        if to_remove:
            (<_AdjacencySetCore>self._core).remove_edges(to_remove)
        return self

    def update(self, *others):
        """Add edges to multiple nodes at once, in a single driver call.

        Each positional argument is either a GraphNode or an iterable of
        GraphNode. Every argument is validated before any edge is added.
        Nodes that are already adjacent, and repeated nodes, are skipped,
        because the driver rejects adding an existing dependency.

        Raises:
            TypeError: if an iterable argument yields a non-GraphNode.
        """
        nodes = []
        for other in others:
            if isinstance(other, GraphNode):
                nodes.append(other)
            else:
                for n in other:
                    if not isinstance(n, GraphNode):
                        raise TypeError(
                            f"expected GraphNode, got {type(n).__name__}")
                    nodes.append(n)
        if not nodes:
            return
        seen = self._member_keys()
        new = []
        for n in nodes:
            key = _node_key(<GraphNode>n)
            if key not in seen:
                seen.add(key)
                new.append(n)
        if new:
            (<_AdjacencySetCore>self._core).add_edges(new)

    def __ior__(self, it):
        """Add edges to all nodes in *it* in a single driver call."""
        self.update(it)
        return self

    def __repr__(self):
        return "{" + ", ".join(repr(n) for n in self) + "}"

108  

109  

110# ---- cdef core holding a function pointer ------------------------------------ 

111  

# Signature shared by driver_get_preds and driver_get_succs:
# (node, output buffer or NULL, in/out element count) -> CUresult.
# Declared noexcept nogil so the pointer can be invoked inside
# ``with nogil`` blocks by _AdjacencySetCore.
ctypedef cydriver.CUresult (*_adj_fn_t)(
    cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil

115  

116  

cdef enum:
    # Capacity of the on-stack buffer used for the common small-degree case.
    _INLINE_CAPACITY = 16


cdef class _AdjacencySetCore:
    """Cythonized core implementing AdjacencySetProxy.

    Holds the owning node and graph handles plus a pointer to the driver
    query function: successors when ``is_fwd`` is true, predecessors
    otherwise. It performs every driver call on behalf of the proxy.
    """
    cdef:
        GraphNodeHandle _h_node
        GraphHandle _h_graph
        _adj_fn_t _query_fn
        bint _is_fwd

    def __init__(self, GraphNode node not None, bint is_fwd):
        self._h_node = node._h_node
        self._h_graph = graph_node_get_graph(node._h_node)
        self._is_fwd = is_fwd
        self._query_fn = driver_get_succs if is_fwd else driver_get_preds

    cdef inline void _resolve_edge(
            self, GraphNode other,
            cydriver.CUgraphNode* c_from,
            cydriver.CUgraphNode* c_to) noexcept:
        # Successor edges run self -> other; predecessor edges run
        # other -> self.
        if self._is_fwd:
            c_from[0] = as_cu(self._h_node)
            c_to[0] = as_cu(other._h_node)
        else:
            c_from[0] = as_cu(other._h_node)
            c_to[0] = as_cu(self._h_node)

    cdef list query(self):
        """Return the current adjacent nodes as a list of GraphNode.

        Returns an empty list when this core's node handle is NULL.
        """
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return []
        cdef cydriver.CUgraphNode buf[_INLINE_CAPACITY]
        cdef size_t count = _INLINE_CAPACITY
        cdef size_t i
        cdef vector[cydriver.CUgraphNode] nodes_vec
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, buf, &count))
        if count >= <size_t>_INLINE_CAPACITY:
            # The driver only documents updating *count* when the buffer is
            # larger than needed, so a completely filled buffer may hold a
            # truncated result. Ask for the exact size.
            count = 0
            with nogil:
                HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
        if count <= <size_t>_INLINE_CAPACITY:
            return [GraphNode._create(self._h_graph, buf[i])
                    for i in range(count)]
        # Large set: fetch everything into a heap buffer of the exact size.
        nodes_vec.resize(count)
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count))
        if count > nodes_vec.size():
            count = nodes_vec.size()
        return [GraphNode._create(self._h_graph, nodes_vec[i])
                for i in range(count)]

    cdef bint contains(self, GraphNode other):
        """Return True if *other* is currently adjacent to this core's node.

        Returns False when either node handle is NULL.
        """
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        cdef cydriver.CUgraphNode target = as_cu(other._h_node)
        if c_node == NULL or target == NULL:
            return False
        cdef cydriver.CUgraphNode buf[_INLINE_CAPACITY]
        cdef size_t count = _INLINE_CAPACITY
        cdef size_t filled
        cdef size_t i
        cdef vector[cydriver.CUgraphNode] nodes_vec
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, buf, &count))

        # Scan what the stack buffer holds; a hit needs no further calls.
        filled = count if count < <size_t>_INLINE_CAPACITY else <size_t>_INLINE_CAPACITY
        for i in range(filled):
            if buf[i] == target:
                return True
        if count < <size_t>_INLINE_CAPACITY:
            return False

        # Buffer was full and may be truncated: get the exact size.
        count = 0
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
        if count <= <size_t>_INLINE_CAPACITY:
            # Every entry already fit in (and was scanned from) the buffer.
            return False

        # Fallback for large sets.
        nodes_vec.resize(count)
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count))
        if count > nodes_vec.size():
            count = nodes_vec.size()
        for i in range(count):
            if nodes_vec[i] == target:
                return True
        return False

    cdef Py_ssize_t count(self):
        """Return the number of adjacent nodes (0 for a NULL node handle)."""
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return 0
        cdef size_t n = 0
        # A NULL output buffer asks the driver only for the count.
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
        return <Py_ssize_t>n

    cdef void add_edge(self, GraphNode other):
        """Add the single edge between this core's node and *other*."""
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef void add_edges(self, list nodes):
        """Add edges to every GraphNode in *nodes* in one driver call."""
        cdef size_t n = len(nodes)
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        from_vec.resize(n)
        to_vec.resize(n)
        cdef size_t i
        for i in range(n):
            self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i])
        with nogil:
            HANDLE_RETURN(driver_add_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))

    cdef void remove_edge(self, GraphNode other):
        """Remove the single edge between this core's node and *other*."""
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef void remove_edges(self, list nodes):
        """Remove edges to every GraphNode in *nodes* in one driver call."""
        cdef size_t n = len(nodes)
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        from_vec.resize(n)
        to_vec.resize(n)
        cdef size_t i
        for i in range(n):
            self._resolve_edge(<GraphNode>nodes[i], &from_vec[i], &to_vec[i])
        with nogil:
            HANDLE_RETURN(driver_remove_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))

237  

238  

239# ---- driver wrappers: absorb CUDA version differences ---- 

240  

cdef inline cydriver.CUresult driver_get_preds(
        cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # Fetch the predecessors (dependencies) of *node* into *out*. *count* is
    # the buffer capacity on input and receives the driver-reported count.
    # Passing out=NULL queries only the count.
    # The branch is selected at compile time from CUDA_CORE_BUILD_MAJOR. The
    # CUDA 13 signature takes an extra argument before the count (edge-data
    # output, per the driver API docs); NULL requests none.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count)
    ELSE:
        return cydriver.cuGraphNodeGetDependencies(node, out, count)

248  

249  

cdef inline cydriver.CUresult driver_get_succs(
        cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # Fetch the successors (dependent nodes) of *node* into *out*. *count* is
    # the buffer capacity on input and receives the driver-reported count.
    # Passing out=NULL queries only the count.
    # The branch is selected at compile time from CUDA_CORE_BUILD_MAJOR. The
    # CUDA 13 signature takes an extra argument before the count (edge-data
    # output, per the driver API docs); NULL requests none.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count)
    ELSE:
        return cydriver.cuGraphNodeGetDependentNodes(node, out, count)

257  

258  

cdef inline cydriver.CUresult driver_add_edges(
        cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
    # Add *count* dependency edges to *graph*; edge k runs
    # from_arr[k] -> to_arr[k].
    # The branch is selected at compile time from CUDA_CORE_BUILD_MAJOR. The
    # CUDA 13 signature takes an extra edge-data argument before the count;
    # NULL passes none.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphAddDependencies(
            graph, from_arr, to_arr, NULL, count)
    ELSE:
        return cydriver.cuGraphAddDependencies(
            graph, from_arr, to_arr, count)

268  

269  

cdef inline cydriver.CUresult driver_remove_edges(
        cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
    # Remove *count* dependency edges from *graph*; edge k runs
    # from_arr[k] -> to_arr[k].
    # The branch is selected at compile time from CUDA_CORE_BUILD_MAJOR. The
    # CUDA 13 signature takes an extra edge-data argument before the count;
    # NULL passes none.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphRemoveDependencies(
            graph, from_arr, to_arr, NULL, count)
    ELSE:
        return cydriver.cuGraphRemoveDependencies(
            graph, from_arr, to_arr, count)