Coverage for cuda/core/graph/_adjacency_set_proxy.pyx: 85.71%

147 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-13 01:38 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5"""Mutable-set proxy for graph node predecessors and successors.""" 

6  

7from libc.stddef cimport size_t 

8from libcpp.vector cimport vector 

9from cuda.bindings cimport cydriver 

10from cuda.core.graph._graph_node cimport GraphNode 

11from cuda.core._resource_handles cimport ( 

12 GraphHandle, 

13 GraphNodeHandle, 

14 as_cu, 

15 graph_node_get_graph, 

16) 

17from cuda.core._utils.cuda_utils cimport HANDLE_RETURN 

18from collections.abc import Iterator, MutableSet, Set 

19from typing import Any 

20  

21  

22# ---- Python MutableSet wrapper ---------------------------------------------- 

23  

class AdjacencySetProxy(MutableSet[GraphNode]):
    """Live, writable set view over one side of a graph node's adjacency.

    Membership, iteration and length are answered by asking the CUDA driver
    on every call. ``add``/``discard`` and the bulk helpers insert or delete
    dependency edges in the underlying graph immediately.
    """

    __slots__ = ("_core",)

    def __init__(self, node: GraphNode, bint is_fwd) -> None:
        # is_fwd=True tracks successors; is_fwd=False tracks predecessors.
        self._core = _AdjacencySetCore(node, is_fwd)

    @classmethod
    def _from_iterable(cls, it) -> set[GraphNode]:
        # collections.abc.Set builds the results of &, |, ^ and - through
        # this hook; hand back a detached builtin set, not another proxy.
        return set(it)

    # --- MutableSet abstract interface ---

    def __contains__(self, x: object) -> bool:
        # Anything that is not a GraphNode can never be a member.
        return isinstance(x, GraphNode) and \
            (<_AdjacencySetCore>self._core).contains(<GraphNode>x)

    def __iter__(self) -> Iterator[GraphNode]:
        # Snapshot taken eagerly when iter() is called.
        cdef list snapshot = (<_AdjacencySetCore>self._core).query()
        return iter(snapshot)

    def __len__(self) -> int:
        cdef _AdjacencySetCore core = <_AdjacencySetCore>self._core
        return core.count()

    def add(self, value: GraphNode) -> None:
        if not isinstance(value, GraphNode):
            raise TypeError(
                f"expected GraphNode, got {type(value).__name__}")
        # Only touch the graph when the edge is not already present.
        if value not in self:
            (<_AdjacencySetCore>self._core).add_edge(<GraphNode>value)

    def discard(self, value: GraphNode) -> None:
        # Non-nodes and non-members are ignored, as MutableSet.discard requires.
        if isinstance(value, GraphNode) and value in self:
            (<_AdjacencySetCore>self._core).remove_edge(<GraphNode>value)

    # --- bulk overrides: one driver call instead of one per element ---

    def clear(self) -> None:
        """Drop every edge on this side of the node with one driver call."""
        cdef _AdjacencySetCore core = <_AdjacencySetCore>self._core
        cdef list members = core.query()
        if members:
            core.remove_edges(members)

    def __isub__(self, it: Set[Any]) -> "AdjacencySetProxy":
        """Drop the edges to every member of *it* with one driver call."""
        if it is self:
            self.clear()
            return self
        doomed = [v for v in it if isinstance(v, GraphNode) and v in self]
        if doomed:
            (<_AdjacencySetCore>self._core).remove_edges(doomed)
        return self

    def update(self, *others) -> None:
        """Add edges to all given nodes with one driver call.

        Each positional argument may be a single GraphNode or an iterable of
        GraphNodes; a non-GraphNode inside an iterable raises TypeError.
        """
        pending = []
        for other in others:
            if isinstance(other, GraphNode):
                # A bare node counts as a one-element argument.
                pending.append(other)
                continue
            for n in other:
                if not isinstance(n, GraphNode):
                    raise TypeError(
                        f"expected GraphNode, got {type(n).__name__}")
                pending.append(n)
        if not pending:
            return
        fresh = [n for n in pending if n not in self]
        if fresh:
            (<_AdjacencySetCore>self._core).add_edges(fresh)

    def __ior__(self, it: Set[Any]) -> "AdjacencySetProxy":
        """In-place union: add edges to every member of *it*."""
        self.update(it)
        return self

    def __repr__(self) -> str:
        return "{" + ", ".join(map(repr, self)) + "}"

109  

110  

111# ---- cdef core holding a function pointer ------------------------------------ 

112  

# Signature shared by driver_get_preds and driver_get_succs.
# _AdjacencySetCore stores one of these (chosen once from is_fwd) so its query
# paths never branch on direction. Arguments: the node being queried, an
# output array of neighbor handles (or NULL to request only the count), and
# an in/out count: callers pass the buffer capacity and read back the
# number of entries reported by the driver.
ctypedef cydriver.CUresult (*_adj_fn_t)(
    cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil

116  

117  

cdef class _AdjacencySetCore:
    """Cythonized core implementing AdjacencySetProxy.

    Stores the owning node and graph handles plus a driver query function
    chosen once in ``__init__``: successors when ``is_fwd`` is true,
    predecessors otherwise. Edge writes are oriented by the same flag.
    """
    cdef:
        GraphNodeHandle _h_node
        GraphHandle _h_graph
        _adj_fn_t _query_fn
        bint _is_fwd

    def __init__(self, GraphNode node, bint is_fwd):
        self._h_node = node._h_node
        self._h_graph = graph_node_get_graph(node._h_node)
        self._is_fwd = is_fwd
        self._query_fn = driver_get_succs if is_fwd else driver_get_preds

    cdef inline void _resolve_edge(
            self, GraphNode other,
            cydriver.CUgraphNode* c_from,
            cydriver.CUgraphNode* c_to) noexcept:
        # Write the endpoints of the edge between this node and *other*:
        # forward edges run self -> other, backward edges run other -> self.
        if self._is_fwd:
            c_from[0] = as_cu(self._h_node)
            c_to[0] = as_cu(other._h_node)
        else:
            c_from[0] = as_cu(other._h_node)
            c_to[0] = as_cu(self._h_node)

    cdef vector[cydriver.CUgraphNode] _fetch_all(
            self, cydriver.CUgraphNode c_node) except *:
        # Exact-size query: ask for the neighbor count first (NULL output
        # buffer, as count() does), then fetch into a buffer of that size.
        cdef vector[cydriver.CUgraphNode] out
        cdef size_t n = 0
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
        if n == 0:
            # Also avoids handing the driver a NULL data() pointer, which it
            # would treat as another count-only query.
            return out
        out.resize(n)
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, out.data(), &n))
        # If the graph shrank between the two calls, the driver reports
        # fewer filled entries; drop the unfilled tail.
        if n < out.size():
            out.resize(n)
        return out

    cdef list query(self):
        # Snapshot of the current neighbors as GraphNode wrappers.
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return []
        cdef cydriver.CUgraphNode buf[16]
        cdef size_t count = 16
        cdef size_t i
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, buf, &count))
        # A count strictly below the buffer size means the driver returned
        # every neighbor. A full buffer is ambiguous: the driver fills at
        # most the given capacity and is not documented to report a larger
        # true count. Fall back to an exact-size query in that case.
        if count < 16:
            return [GraphNode._create(self._h_graph, buf[i])
                    for i in range(count)]
        cdef vector[cydriver.CUgraphNode] nodes_vec = self._fetch_all(c_node)
        return [GraphNode._create(self._h_graph, nodes_vec[i])
                for i in range(nodes_vec.size())]

    cdef bint contains(self, GraphNode other):
        # True if *other* is currently a neighbor in this direction.
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        cdef cydriver.CUgraphNode target = as_cu(other._h_node)
        if c_node == NULL or target == NULL:
            return False
        cdef cydriver.CUgraphNode buf[16]
        cdef size_t count = 16
        cdef size_t filled
        cdef size_t i
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, buf, &count))
        # Never read past the stack buffer, whatever count comes back.
        filled = count if count < 16 else 16
        # Every filled entry is a genuine neighbor, so a hit here is
        # conclusive even if the list was truncated.
        for i in range(filled):
            if buf[i] == target:
                return True
        # Fewer than 16 entries means the full neighbor list was returned.
        if count < 16:
            return False
        # Full buffer: the list may have been truncated, so re-query at
        # exact size.
        cdef vector[cydriver.CUgraphNode] nodes_vec = self._fetch_all(c_node)
        for i in range(nodes_vec.size()):
            if nodes_vec[i] == target:
                return True
        return False

    cdef Py_ssize_t count(self):
        # Number of neighbors, via the driver's NULL-buffer count query.
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return 0
        cdef size_t n = 0
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
        return <Py_ssize_t>n

    cdef void _fill_edges(
            self, list nodes,
            vector[cydriver.CUgraphNode]& from_vec,
            vector[cydriver.CUgraphNode]& to_vec):
        # Build one oriented (from, to) pair per distinct node in *nodes*.
        # Entries whose raw driver handle was already seen are skipped, so a
        # batch never names the same edge twice. The CUDA driver docs report
        # an error for adding an existing dependency or removing a missing
        # one.
        cdef set seen = set()
        cdef cydriver.CUgraphNode c_from, c_to
        cdef GraphNode other
        cdef size_t key
        from_vec.clear()
        to_vec.clear()
        from_vec.reserve(len(nodes))
        to_vec.reserve(len(nodes))
        for obj in nodes:
            other = <GraphNode>obj
            key = <size_t>as_cu(other._h_node)
            if key in seen:
                continue
            seen.add(key)
            self._resolve_edge(other, &c_from, &c_to)
            from_vec.push_back(c_from)
            to_vec.push_back(c_to)

    cdef void add_edge(self, GraphNode other):
        # Add the single edge between this node and *other*.
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef void add_edges(self, list nodes):
        # Add edges to every distinct node in *nodes* with one driver call.
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        self._fill_edges(nodes, from_vec, to_vec)
        if from_vec.empty():
            return
        with nogil:
            HANDLE_RETURN(driver_add_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(),
                from_vec.size()))

    cdef void remove_edge(self, GraphNode other):
        # Remove the single edge between this node and *other*.
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef void remove_edges(self, list nodes):
        # Remove edges to every distinct node in *nodes* with one driver call.
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        self._fill_edges(nodes, from_vec, to_vec)
        if from_vec.empty():
            return
        with nogil:
            HANDLE_RETURN(driver_remove_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(),
                from_vec.size()))

238  

239  

240# ---- driver wrappers: absorb CUDA version differences ---- 

241  

cdef inline cydriver.CUresult driver_get_preds(
        cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # Query the predecessors of *node* (the nodes it depends on) through
    # cuGraphNodeGetDependencies, hiding the CUDA 12 vs 13 signature change.
    # With out == NULL the neighbor count is written to *count; otherwise
    # *count is the capacity of out and receives the number reported.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        # NOTE(review): the extra NULL is presumably the edge-data output
        # array of the newer signature. The driver docs describe
        # CUDA_ERROR_LOSSY_QUERY when it is NULL but an edge carries
        # non-default data; confirm such edges cannot reach this path.
        return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count)
    ELSE:
        return cydriver.cuGraphNodeGetDependencies(node, out, count)

249  

250  

cdef inline cydriver.CUresult driver_get_succs(
        cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # Query the successors of *node* (the nodes depending on it) through
    # cuGraphNodeGetDependentNodes, hiding the CUDA 12 vs 13 signature change.
    # With out == NULL the neighbor count is written to *count; otherwise
    # *count is the capacity of out and receives the number reported.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        # NOTE(review): the extra NULL is presumably the edge-data output
        # array of the newer signature; see driver_get_preds for the
        # lossy-query caveat.
        return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count)
    ELSE:
        return cydriver.cuGraphNodeGetDependentNodes(node, out, count)

258  

259  

cdef inline cydriver.CUresult driver_add_edges(
        cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
    # Add *count* dependency edges from_arr[i] -> to_arr[i] to *graph*,
    # hiding the CUDA 12 vs 13 signature change of cuGraphAddDependencies.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        # NOTE(review): NULL presumably selects default edge data in the
        # newer signature; confirm against the CUDA 13 driver headers.
        return cydriver.cuGraphAddDependencies(
            graph, from_arr, to_arr, NULL, count)
    ELSE:
        return cydriver.cuGraphAddDependencies(
            graph, from_arr, to_arr, count)

269  

270  

cdef inline cydriver.CUresult driver_remove_edges(
        cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
    # Remove *count* dependency edges from_arr[i] -> to_arr[i] from *graph*,
    # hiding the CUDA 12 vs 13 signature change of cuGraphRemoveDependencies.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        # NOTE(review): NULL presumably means "match edges with default
        # edge data" in the newer signature; confirm against CUDA 13 docs.
        return cydriver.cuGraphRemoveDependencies(
            graph, from_arr, to_arr, NULL, count)
    ELSE:
        return cydriver.cuGraphRemoveDependencies(
            graph, from_arr, to_arr, count)