Coverage for cuda/core/graph/_adjacency_set_proxy.pyx: 85.71%
147 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-13 01:38 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5"""Mutable-set proxy for graph node predecessors and successors."""
7from libc.stddef cimport size_t
8from libcpp.vector cimport vector
9from cuda.bindings cimport cydriver
10from cuda.core.graph._graph_node cimport GraphNode
11from cuda.core._resource_handles cimport (
12 GraphHandle,
13 GraphNodeHandle,
14 as_cu,
15 graph_node_get_graph,
16)
17from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
18from collections.abc import Iterator, MutableSet, Set
19from typing import Any
22# ---- Python MutableSet wrapper ----------------------------------------------
class AdjacencySetProxy(MutableSet[GraphNode]):
    """Mutable set proxy for a node's predecessors or successors.

    Nothing is cached: membership, iteration and ``len`` ask the CUDA driver
    each time, and every mutation writes through to the underlying CUDA
    graph. Bulk operations (``clear``, ``update``, ``|=``, ``-=``) issue a
    single add/remove driver call for the whole batch.
    """

    __slots__ = ("_core",)

    def __init__(self, node: GraphNode, bint is_fwd) -> None:
        # is_fwd=True views successors, False views predecessors.
        self._core = _AdjacencySetCore(node, is_fwd)

    @staticmethod
    def _handle_key(node):
        """Return the raw driver handle of *node* (a GraphNode) as an int.

        Membership tests in the core compare nodes by this raw handle, so
        equal keys mean "same graph node" even for distinct wrapper objects.
        """
        return <size_t>as_cu((<GraphNode>node)._h_node)

    # Binary operators such as & | ^ - build their result through this hook;
    # return a plain set rather than another write-through proxy.
    @classmethod
    def _from_iterable(cls, it) -> set[GraphNode]:
        return set(it)

    # --- abstract methods required by MutableSet ---

    def __contains__(self, x: object) -> bool:
        if not isinstance(x, GraphNode):
            return False
        return (<_AdjacencySetCore>self._core).contains(<GraphNode>x)

    def __iter__(self) -> Iterator[GraphNode]:
        # Iterates over a freshly queried list, i.e. a snapshot.
        return iter((<_AdjacencySetCore>self._core).query())

    def __len__(self) -> int:
        return (<_AdjacencySetCore>self._core).count()

    def add(self, value: GraphNode) -> None:
        if not isinstance(value, GraphNode):
            raise TypeError(
                f"expected GraphNode, got {type(value).__name__}")
        if value in self:
            return
        (<_AdjacencySetCore>self._core).add_edge(<GraphNode>value)

    def discard(self, value: GraphNode) -> None:
        if not isinstance(value, GraphNode):
            return
        if value not in self:
            return
        (<_AdjacencySetCore>self._core).remove_edge(<GraphNode>value)

    # --- override for bulk efficiency ---

    def clear(self) -> None:
        """Remove all edges in a single driver call."""
        members = (<_AdjacencySetCore>self._core).query()
        if members:
            (<_AdjacencySetCore>self._core).remove_edges(members)

    def __isub__(self, it: Set[Any]) -> "AdjacencySetProxy":
        """Remove edges to all nodes in *it* in a single driver call.

        Items that are not GraphNodes or not members are ignored. A node
        that appears more than once in *it* is removed only once; the driver
        reports an error for removing a dependency that does not exist.
        """
        if it is self:
            self.clear()
            return self
        # One driver query for the current members, keyed by raw handle.
        present = {self._handle_key(m)
                   for m in (<_AdjacencySetCore>self._core).query()}
        to_remove = []
        for v in it:
            if not isinstance(v, GraphNode):
                continue
            key = self._handle_key(v)
            if key in present:
                # Forget the key so a repeated node is not removed twice.
                present.discard(key)
                to_remove.append(v)
        if to_remove:
            (<_AdjacencySetCore>self._core).remove_edges(to_remove)
        return self

    def update(self, *others) -> None:
        """Add edges to multiple nodes at once.

        Each argument is either a single GraphNode or an iterable of
        GraphNodes. All items are validated before the graph is touched.
        Nodes already present, or repeated across/within the arguments, are
        skipped: the driver reports an error for adding an existing
        dependency.

        Raises:
            TypeError: if an iterable argument yields a non-GraphNode.
        """
        nodes = []
        for other in others:
            if isinstance(other, GraphNode):
                nodes.append(other)
            else:
                for n in other:
                    if not isinstance(n, GraphNode):
                        raise TypeError(
                            f"expected GraphNode, got {type(n).__name__}")
                    nodes.append(n)
        if not nodes:
            return
        # One driver query for the current members instead of one per node;
        # newly accepted keys are added so duplicates are filtered too.
        seen = {self._handle_key(m)
                for m in (<_AdjacencySetCore>self._core).query()}
        new = []
        for n in nodes:
            key = self._handle_key(n)
            if key not in seen:
                seen.add(key)
                new.append(n)
        if new:
            (<_AdjacencySetCore>self._core).add_edges(new)

    def __ior__(self, it: Set[Any]) -> "AdjacencySetProxy":
        """Add edges to all nodes in *it* in a single driver call."""
        self.update(it)
        return self

    def __repr__(self) -> str:
        return "{" + ", ".join(repr(n) for n in self) + "}"
111# ---- cdef core holding a function pointer ------------------------------------
# Function-pointer type matching both driver_get_preds and driver_get_succs:
# (node to inspect, output array or NULL, in/out element count) -> status.
ctypedef cydriver.CUresult (*_adj_fn_t)(
    cydriver.CUgraphNode,
    cydriver.CUgraphNode*,
    size_t*) noexcept nogil
cdef class _AdjacencySetCore:
    """Cythonized core implementing AdjacencySetProxy.

    Holds the viewed node, its owning graph and the driver query for the
    chosen direction, and performs every driver call made by the proxy.
    """
    cdef:
        # Node whose predecessors/successors are viewed.
        GraphNodeHandle _h_node
        # Graph that owns _h_node; edge edits are applied to it.
        GraphHandle _h_graph
        # driver_get_succs when _is_fwd, otherwise driver_get_preds.
        _adj_fn_t _query_fn
        # True: edges run _h_node -> other; False: other -> _h_node.
        bint _is_fwd

    def __init__(self, GraphNode node, bint is_fwd):
        self._h_node = node._h_node
        self._h_graph = graph_node_get_graph(node._h_node)
        self._is_fwd = is_fwd
        self._query_fn = driver_get_succs if is_fwd else driver_get_preds

    cdef inline void _resolve_edge(
            self, GraphNode other,
            cydriver.CUgraphNode* c_from,
            cydriver.CUgraphNode* c_to) noexcept:
        # Orient the edge between this node and *other* by direction.
        if self._is_fwd:
            c_from[0] = as_cu(self._h_node)
            c_to[0] = as_cu(other._h_node)
        else:
            c_from[0] = as_cu(other._h_node)
            c_to[0] = as_cu(self._h_node)

    cdef int _fetch(self, cydriver.CUgraphNode c_node,
                    vector[cydriver.CUgraphNode]& out) except -1:
        """Store the raw handles of all neighbours of *c_node* in *out*.

        The common case (fewer than 16 neighbours) costs one driver call
        into a stack buffer.
        """
        cdef cydriver.CUgraphNode buf[16]
        cdef size_t count = 16
        cdef size_t i
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, buf, &count))
        if count == 16:
            # A full buffer is ambiguous. The driver only documents the
            # returned count as the true total when the buffer was larger
            # than needed, so more neighbours may exist. Ask for the total.
            with nogil:
                HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
        if count <= 16:
            out.resize(count)
            for i in range(count):
                out[i] = buf[i]
            return 0
        out.resize(count)
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, out.data(), &count))
        # Keep only as many entries as the driver reported.
        out.resize(count)
        return 0

    cdef list query(self):
        """Return the neighbours as GraphNode wrappers ([] if node is NULL)."""
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return []
        cdef vector[cydriver.CUgraphNode] handles
        cdef size_t i
        self._fetch(c_node, handles)
        return [GraphNode._create(self._h_graph, handles[i])
                for i in range(handles.size())]

    cdef bint contains(self, GraphNode other):
        """Return True if *other*'s raw handle is among the neighbours."""
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        cdef cydriver.CUgraphNode target = as_cu(other._h_node)
        if c_node == NULL or target == NULL:
            return False
        cdef vector[cydriver.CUgraphNode] handles
        cdef size_t i
        self._fetch(c_node, handles)
        for i in range(handles.size()):
            if handles[i] == target:
                return True
        return False

    cdef Py_ssize_t count(self):
        """Return the number of neighbours (0 if the node handle is NULL)."""
        cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
        if c_node == NULL:
            return 0
        cdef size_t n = 0
        # A NULL output array asks the driver for the count only.
        with nogil:
            HANDLE_RETURN(self._query_fn(c_node, NULL, &n))
        return <Py_ssize_t>n

    cdef int _collect_edges(self, list nodes,
                            vector[cydriver.CUgraphNode]& from_vec,
                            vector[cydriver.CUgraphNode]& to_vec) except -1:
        """Fill parallel edge arrays for *nodes*, one edge per distinct node.

        Nodes are compared by raw driver handle. Repeats are skipped because
        the driver rejects adding an existing dependency and removing a
        missing one; a repeat within one batch presumably trips the same
        checks.
        """
        cdef set seen = set()
        cdef GraphNode other
        cdef cydriver.CUgraphNode c_from
        cdef cydriver.CUgraphNode c_to
        from_vec.reserve(len(nodes))
        to_vec.reserve(len(nodes))
        for item in nodes:
            other = <GraphNode>item
            key = <size_t>as_cu(other._h_node)
            if key in seen:
                continue
            seen.add(key)
            self._resolve_edge(other, &c_from, &c_to)
            from_vec.push_back(c_from)
            to_vec.push_back(c_to)
        return 0

    cdef void add_edge(self, GraphNode other):
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef void add_edges(self, list nodes):
        """Add one edge per distinct node in *nodes* with one driver call."""
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        self._collect_edges(nodes, from_vec, to_vec)
        if from_vec.empty():
            return
        with nogil:
            HANDLE_RETURN(driver_add_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(),
                from_vec.size()))

    cdef void remove_edge(self, GraphNode other):
        cdef cydriver.CUgraphNode c_from, c_to
        self._resolve_edge(other, &c_from, &c_to)
        with nogil:
            HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1))

    cdef void remove_edges(self, list nodes):
        """Remove one edge per distinct node in *nodes* with one driver call."""
        cdef vector[cydriver.CUgraphNode] from_vec
        cdef vector[cydriver.CUgraphNode] to_vec
        self._collect_edges(nodes, from_vec, to_vec)
        if from_vec.empty():
            return
        with nogil:
            HANDLE_RETURN(driver_remove_edges(
                as_cu(self._h_graph), from_vec.data(), to_vec.data(),
                from_vec.size()))
240# ---- driver wrappers: absorb CUDA version differences ----
cdef inline cydriver.CUresult driver_get_preds(
        cydriver.CUgraphNode node,
        cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # List the nodes *node* depends on (its predecessors).
    IF CUDA_CORE_BUILD_MAJOR < 13:
        return cydriver.cuGraphNodeGetDependencies(node, out, count)
    ELSE:
        # CUDA 13 inserts an extra edge-data output argument (per the
        # driver API docs); NULL leaves it unrequested.
        return cydriver.cuGraphNodeGetDependencies(
            node, out, NULL, count)
cdef inline cydriver.CUresult driver_get_succs(
        cydriver.CUgraphNode node,
        cydriver.CUgraphNode* out,
        size_t* count) noexcept nogil:
    # List the nodes that depend on *node* (its successors).
    IF CUDA_CORE_BUILD_MAJOR < 13:
        return cydriver.cuGraphNodeGetDependentNodes(node, out, count)
    ELSE:
        # CUDA 13 inserts an extra edge-data output argument (per the
        # driver API docs); NULL leaves it unrequested.
        return cydriver.cuGraphNodeGetDependentNodes(
            node, out, NULL, count)
cdef inline cydriver.CUresult driver_add_edges(
        cydriver.CUgraph graph,
        cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr,
        size_t count) noexcept nogil:
    # Add `count` edges from_arr[i] -> to_arr[i] to `graph` in one call.
    IF CUDA_CORE_BUILD_MAJOR < 13:
        return cydriver.cuGraphAddDependencies(graph, from_arr, to_arr, count)
    ELSE:
        # CUDA 13 takes an extra edge-data array (per the driver API docs);
        # NULL requests default edges.
        return cydriver.cuGraphAddDependencies(
            graph, from_arr, to_arr, NULL, count)
cdef inline cydriver.CUresult driver_remove_edges(
        cydriver.CUgraph graph,
        cydriver.CUgraphNode* from_arr,
        cydriver.CUgraphNode* to_arr,
        size_t count) noexcept nogil:
    # Remove `count` edges from_arr[i] -> to_arr[i] from `graph` in one call.
    IF CUDA_CORE_BUILD_MAJOR < 13:
        return cydriver.cuGraphRemoveDependencies(graph, from_arr, to_arr, count)
    ELSE:
        # CUDA 13 takes an extra edge-data array (per the driver API docs);
        # NULL matches default edges.
        return cydriver.cuGraphRemoveDependencies(
            graph, from_arr, to_arr, NULL, count)