Coverage for cuda / core / graph / _adjacency_set_proxy.pyx: 86.99%
146 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 01:27 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5"""Mutable-set proxy for graph node predecessors and successors."""
7from libc.stddef cimport size_t
8from libcpp.vector cimport vector
9from cuda.bindings cimport cydriver
10from cuda.core.graph._graph_node cimport GraphNode
11from cuda.core._resource_handles cimport (
12 GraphHandle,
13 GraphNodeHandle,
14 as_cu,
15 graph_node_get_graph,
16)
17from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
18from collections.abc import MutableSet
21# ---- Python MutableSet wrapper ----------------------------------------------
class AdjacencySetProxy(MutableSet):
    """Live, mutable set view of a node's predecessors or successors.

    Every read queries the CUDA graph; every mutation adds or removes graph
    edges immediately. Only GraphNode objects can be members.
    """

    __slots__ = ("_core",)

    def __init__(self, node, bint is_fwd):
        # is_fwd chooses the successor view (True) or predecessor view (False).
        self._core = _AdjacencySetCore(node, is_fwd)

    @classmethod
    def _from_iterable(cls, it):
        # MutableSet builds the results of &, |, ^ and - through this hook.
        # Returning a plain set keeps those results detached from the graph.
        return set(it)

    # --- abstract methods required by MutableSet ---

    def __contains__(self, x):
        cdef _AdjacencySetCore core
        if isinstance(x, GraphNode):
            core = self._core
            return core.contains(<GraphNode>x)
        return False

    def __iter__(self):
        cdef _AdjacencySetCore core = self._core
        # Snapshot the neighbors once; the iterator does not track later edits.
        snapshot = core.query()
        return iter(snapshot)

    def __len__(self):
        cdef _AdjacencySetCore core = self._core
        return core.count()

    def add(self, value):
        cdef _AdjacencySetCore core
        if not isinstance(value, GraphNode):
            raise TypeError(
                f"expected GraphNode, got {type(value).__name__}")
        # Only create the edge when it is not already present.
        if value not in self:
            core = self._core
            core.add_edge(<GraphNode>value)

    def discard(self, value):
        cdef _AdjacencySetCore core
        # Non-nodes and non-members are ignored, matching set.discard.
        if isinstance(value, GraphNode) and value in self:
            core = self._core
            core.remove_edge(<GraphNode>value)

    # --- override for bulk efficiency ---

    def clear(self):
        """Remove every edge of this view with one batched driver call."""
        cdef _AdjacencySetCore core = self._core
        current = core.query()
        if current:
            core.remove_edges(current)

    def __isub__(self, it):
        """Remove edges to the members of *it*, batched into one driver call.

        Objects in *it* that are not GraphNodes, or not current members,
        are skipped.
        """
        cdef _AdjacencySetCore core
        if it is self:
            self.clear()
            return self
        doomed = []
        for candidate in it:
            if isinstance(candidate, GraphNode) and candidate in self:
                doomed.append(candidate)
        if doomed:
            core = self._core
            core.remove_edges(doomed)
        return self

    def update(self, *others):
        """Add edges to every node in each of *others*.

        A bare GraphNode argument counts as a single node. Every element is
        type-checked before any edge is created, so a TypeError leaves the
        graph untouched. Nodes that are already members are skipped.
        """
        cdef _AdjacencySetCore core
        gathered = []
        for other in others:
            if isinstance(other, GraphNode):
                gathered.append(other)
                continue
            for n in other:
                if not isinstance(n, GraphNode):
                    raise TypeError(
                        f"expected GraphNode, got {type(n).__name__}")
                gathered.append(n)
        if not gathered:
            return
        missing = [node for node in gathered if node not in self]
        if missing:
            core = self._core
            core.add_edges(missing)

    def __ior__(self, it):
        """Add edges to the members of *it*, batched into one driver call."""
        self.update(it)
        return self

    def __repr__(self):
        body = ", ".join([repr(n) for n in self])
        return "{" + body + "}"
110# ---- cdef core holding a function pointer ------------------------------------
# Function-pointer type shared by driver_get_preds and driver_get_succs.
# _AdjacencySetCore stores one of the two at construction time, so later
# queries do not re-branch on direction.
# Arguments: the node being queried; an output array, or NULL to request only
# the count; and an in/out element count (the capacity of the output array on
# entry, the number of entries on return).
ctypedef cydriver.CUresult (*_adj_fn_t)(
    cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil
cdef enum:
    # Capacity of the on-stack buffer used for the common case of a node with
    # only a few neighbors. Larger neighborhoods use an exact-size heap query.
    _INLINE_NEIGHBORS = 16


cdef class _AdjacencySetCore:
    """Cythonized core implementing AdjacencySetProxy.

    Holds handles to one graph node and its parent graph, plus the driver
    neighbor query selected at construction. When ``is_fwd`` is true the set
    holds the node's successors and edges run node -> member. Otherwise it
    holds predecessors and edges run member -> node.
    """
    cdef:
        GraphNodeHandle _h_node
        GraphHandle _h_graph
        _adj_fn_t _query_fn
        bint _is_fwd

    def __init__(self, GraphNode node not None, bint is_fwd):
        # ``not None``: Cython does not None-check typed extension arguments
        # by default. Reject None up front instead of reading its attributes.
        self._h_node = node._h_node
        self._h_graph = graph_node_get_graph(node._h_node)
        self._is_fwd = is_fwd
        self._query_fn = driver_get_succs if is_fwd else driver_get_preds

    cdef inline void _resolve_edge(
            self, GraphNode other,
            cydriver.CUgraphNode* c_from,
            cydriver.CUgraphNode* c_to) noexcept:
        # Orient the edge between this node and *other* according to the
        # view's direction.
        if self._is_fwd:
            c_from[0] = as_cu(self._h_node)
            c_to[0] = as_cu(other._h_node)
        else:
            c_from[0] = as_cu(other._h_node)
            c_to[0] = as_cu(self._h_node)

    cdef int _query_all(
            self, cydriver.CUgraphNode c_node,
            vector[cydriver.CUgraphNode]& out) except -1:
        """Fill *out* with every neighbor of *c_node* (slow path).

        First asks the driver for the exact count (NULL output array), then
        fetches into a buffer of that size. *out* is trimmed to the number of
        entries the driver reports writing, so callers iterate ``out.size()``.
        It is never grown past the allocated size.
        """
        cdef size_t n = 0
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
        out.resize(n)
        if n == 0:
            # An empty vector's data() may be NULL, which would turn the
            # second call into another count-only query.
            return 0
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, out.data(), &n))
        if n < out.size():
            out.resize(n)
        return 0

    cdef list query(self):
        """Return the current neighbors as a list of GraphNode objects."""
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return []
        cdef cydriver.CUgraphNode buf[_INLINE_NEIGHBORS]
        cdef size_t count = _INLINE_NEIGHBORS
        cdef size_t i
        cdef vector[cydriver.CUgraphNode] nodes_vec
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, buf, &count))
        # Fewer entries than the buffer holds means the list is complete.
        # A full buffer may be truncated, because the driver is not required
        # to report the real total when the capacity is too small. In that
        # case, re-query at the exact size.
        if count < _INLINE_NEIGHBORS:
            return [GraphNode._create(self._h_graph, buf[i])
                    for i in range(count)]
        self._query_all(c_node, nodes_vec)
        return [GraphNode._create(self._h_graph, nodes_vec[i])
                for i in range(nodes_vec.size())]

    cdef bint contains(self, GraphNode other):
        """Return True if *other* is currently a neighbor of this node."""
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        cdef cydriver.CUgraphNode target = as_cu(other._h_node)
        if c_node == NULL or target == NULL:
            return False
        cdef cydriver.CUgraphNode buf[_INLINE_NEIGHBORS]
        cdef size_t count = _INLINE_NEIGHBORS
        cdef size_t i
        cdef vector[cydriver.CUgraphNode] nodes_vec
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, buf, &count))

        # Fast path: a partially filled buffer holds the complete list.
        if count < _INLINE_NEIGHBORS:
            for i in range(count):
                if buf[i] == target:
                    return True
            return False

        # Full buffer: it may be truncated, so scan an exact-size query.
        self._query_all(c_node, nodes_vec)
        for i in range(nodes_vec.size()):
            if nodes_vec[i] == target:
                return True
        return False

    cdef Py_ssize_t count(self):
        """Return the number of neighbors; 0 for a NULL node handle."""
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return 0
        cdef size_t n = 0
        with nogil:
            # A NULL output array asks the driver for the count only.
            HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
        return <Py_ssize_t>n

    cdef void add_edge(self, GraphNode other):
        """Add the single edge between this node and *other*."""
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef int _collect_edges(
            self, list nodes,
            vector[cydriver.CUgraphNode]& from_vec,
            vector[cydriver.CUgraphNode]& to_vec) except -1:
        """Resolve *nodes* into parallel (from, to) edge arrays.

        Nodes that repeat (same underlying CUgraphNode) are emitted once, so
        an input such as ``[a, a]`` never hands the driver the same edge twice
        in one batched call. Elements of *nodes* are assumed to be GraphNode
        instances; callers type-check them beforehand.
        """
        cdef set seen = set()
        cdef GraphNode other
        cdef cydriver.CUgraphNode c_from
        cdef cydriver.CUgraphNode c_to
        cdef size_t key
        from_vec.reserve(len(nodes))
        to_vec.reserve(len(nodes))
        for obj in nodes:
            other = <GraphNode>obj
            # Dedup on the raw handle address: distinct GraphNode wrappers can
            # refer to the same driver node.
            key = <size_t>as_cu(other._h_node)
            if key in seen:
                continue
            seen.add(key)
            self._resolve_edge(other, &c_from, &c_to)
            from_vec.push_back(c_from)
            to_vec.push_back(c_to)
        return 0

    cdef void add_edges(self, list nodes):
        """Add edges to all (distinct) *nodes* in one driver call."""
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        cdef size_t n
        self._collect_edges(nodes, from_vec, to_vec)
        n = from_vec.size()
        if n == 0:
            return
        with nogil:
            HANDLE_RETURN(driver_add_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))

    cdef void remove_edge(self, GraphNode other):
        """Remove the single edge between this node and *other*."""
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef void remove_edges(self, list nodes):
        """Remove edges to all (distinct) *nodes* in one driver call."""
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        cdef size_t n
        self._collect_edges(nodes, from_vec, to_vec)
        n = from_vec.size()
        if n == 0:
            return
        with nogil:
            HANDLE_RETURN(driver_remove_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(), n))
239# ---- driver wrappers: absorb CUDA version differences ----
# Query a node's predecessors (its dependencies) in the graph.
# *count* is the capacity of *out* on entry and the number of entries on
# return. Passing out=NULL requests only the count (as
# _AdjacencySetCore.count does).
# NOTE(review): the CUDA >= 13 branch passes NULL for the extra (edge-data)
# argument. Per the driver docs, this may fail with CUDA_ERROR_LOSSY_QUERY when
# an edge carries non-default edge data; confirm whether that needs handling.
cdef inline cydriver.CUresult driver_get_preds(
        cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # Compile-time switch: the CUDA 13 signature has an additional argument.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count)
    ELSE:
        return cydriver.cuGraphNodeGetDependencies(node, out, count)
# Query a node's successors (the nodes that depend on it) in the graph.
# Same in/out *count* convention as driver_get_preds: capacity on entry, number
# of entries on return, and out=NULL requests only the count.
# NOTE(review): the NULL edge-data argument in the CUDA >= 13 branch has the
# same possible CUDA_ERROR_LOSSY_QUERY caveat as the predecessor query; confirm.
cdef inline cydriver.CUresult driver_get_succs(
        cydriver.CUgraphNode node, cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # Compile-time switch: the CUDA 13 signature has an additional argument.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count)
    ELSE:
        return cydriver.cuGraphNodeGetDependentNodes(node, out, count)
# Add *count* edges to *graph*; edge k runs from_arr[k] -> to_arr[k].
# Callers batch many edges into one call (see _AdjacencySetCore.add_edges).
# The driver returns an error for an edge that already exists, so callers
# should not pass duplicates or existing edges.
cdef inline cydriver.CUresult driver_add_edges(
        cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
    # Compile-time switch: CUDA 13 inserts an extra (edge-data) argument
    # before the count; NULL is passed here.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphAddDependencies(
            graph, from_arr, to_arr, NULL, count)
    ELSE:
        return cydriver.cuGraphAddDependencies(
            graph, from_arr, to_arr, count)
# Remove *count* edges from *graph*; edge k is from_arr[k] -> to_arr[k].
# The driver returns an error for an edge that does not exist, so callers
# should only pass edges that are currently present.
cdef inline cydriver.CUresult driver_remove_edges(
        cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil:
    # Compile-time switch: CUDA 13 inserts an extra (edge-data) argument
    # before the count; NULL is passed here.
    IF CUDA_CORE_BUILD_MAJOR >= 13:
        return cydriver.cuGraphRemoveDependencies(
            graph, from_arr, to_arr, NULL, count)
    ELSE:
        return cydriver.cuGraphRemoveDependencies(
            graph, from_arr, to_arr, count)