Coverage for cuda / core / experimental / _stream.pyx: 88%
201 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-10 01:19 +0000
1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7from libc.stdint cimport uintptr_t, INT32_MIN
8from libc.stdlib cimport strtol, getenv
10from cuda.bindings cimport cydriver
12from cuda.core.experimental._event cimport Event as cyEvent
13from cuda.core.experimental._utils.cuda_utils cimport (
14 check_or_create_options,
15 CU_CONTEXT_INVALID,
16 get_device_from_ctx,
17 HANDLE_RETURN,
18)
20import cython
21import warnings
22from dataclasses import dataclass
23from typing import TYPE_CHECKING, Optional, Protocol, Union
25if TYPE_CHECKING:
26 import cuda.bindings
27 from cuda.core.experimental._device import Device
28from cuda.core.experimental._context import Context
29from cuda.core.experimental._event import Event, EventOptions
30from cuda.core.experimental._graph import GraphBuilder
31from cuda.core.experimental._utils.cuda_utils import (
32 driver,
33)
@dataclass
cdef class StreamOptions:
    """Options controlling how a :obj:`~_stream.Stream` is created.

    Attributes
    ----------
    nonblocking : bool, optional
        If True, the stream does not synchronize with the NULL stream.
        (Defaults to True)
    priority : int, optional
        Stream priority, where a numerically lower value represents a
        higher priority. (Defaults to the lowest priority)

    """

    # Selects CU_STREAM_NON_BLOCKING (True) or CU_STREAM_DEFAULT (False)
    # when the stream is created.
    nonblocking: cython.bint = True
    # None lets the stream factory pick the lowest available priority.
    priority: Optional[int] = None
class IsStreamT(Protocol):
    def __cuda_stream__(self) -> tuple[int, int]:
        """
        Advertise this object as a CUDA stream.

        Any Python object that should be interpreted as a CUDA stream can
        signal that intent by implementing this protocol method. It returns a
        2-tuple of Python `int` values: first the protocol version number
        (currently ``0``), then the address of the underlying ``cudaStream_t``.
        """
        ...
cdef class Stream:
    """Represent a queue of GPU operations that are executed in a specific order.

    Applications use streams to control the order of execution for
    GPU work. Work within a single stream is executed sequentially,
    whereas work across multiple streams can be further controlled
    using stream priorities and :obj:`~_event.Event` management.

    Advanced users can utilize default streams to enforce complex
    implicit synchronization behaviors.

    Directly creating a :obj:`~_stream.Stream` is not supported due to ambiguity.
    New streams should instead be created through a :obj:`~_device.Device`
    object, or created directly through using an existing handle
    using Stream.from_handle().
    """
    def __cinit__(self):
        self._handle = <cydriver.CUstream>(NULL)
        # Foreign object whose stream we borrow; None means we own _handle.
        self._owner = None
        # True only for the built-in default-stream singletons.
        self._builtin = False
        self._nonblocking = -1  # lazy init'd
        self._priority = INT32_MIN  # lazy init'd
        self._device_id = cydriver.CU_DEVICE_INVALID  # lazy init'd
        self._ctx_handle = CU_CONTEXT_INVALID  # lazy init'd

    def __init__(self, *args, **kwargs):
        raise RuntimeError(
            "Stream objects cannot be instantiated directly. "
            "Please use Device APIs (create_stream) or other Stream APIs (from_handle)."
        )

    @classmethod
    def _legacy_default(cls):
        # Build a non-owning Stream wrapping the CU_STREAM_LEGACY handle.
        cdef Stream self = Stream.__new__(cls)
        self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_LEGACY)
        self._builtin = True
        return self

    @classmethod
    def _per_thread_default(cls):
        # Build a non-owning Stream wrapping the CU_STREAM_PER_THREAD handle.
        cdef Stream self = Stream.__new__(cls)
        self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_PER_THREAD)
        self._builtin = True
        return self

    @classmethod
    def _init(cls, obj: IsStreamT | None = None, options=None, device_id: int = None):
        # Internal constructor. Either wraps a foreign `obj` implementing
        # __cuda_stream__ (borrowed, kept alive via _owner) or creates a new
        # driver stream from `options`. The two are mutually exclusive.
        #
        # Raises ValueError if both are given, or if the requested priority is
        # outside the current context's supported range.
        cdef Stream self = Stream.__new__(cls)

        if obj is not None and options is not None:
            raise ValueError("obj and options cannot be both specified")
        if obj is not None:
            self._handle = _handle_from_stream_protocol(obj)
            # TODO: check if obj is created under the current context/device
            self._owner = obj
            return self

        cdef StreamOptions opts = check_or_create_options(StreamOptions, options, "Stream options")
        nonblocking = opts.nonblocking
        priority = opts.priority

        flags = cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING if nonblocking else cydriver.CUstream_flags.CU_STREAM_DEFAULT
        # TODO: we might want to consider memoizing the range per CUDA context and avoid this call
        # cuCtxGetStreamPriorityRange(leastPriority, greatestPriority): numerically
        # lower values are higher priority, so greatest_prio <= least_prio.
        cdef int least_prio, greatest_prio
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxGetStreamPriorityRange(&least_prio, &greatest_prio))
        cdef int prio
        if priority is not None:
            prio = priority
            if not (greatest_prio <= prio <= least_prio):
                raise ValueError(f"{priority=} is out of range {[greatest_prio, least_prio]}")
        else:
            # Default to the lowest priority, as documented on StreamOptions.
            prio = least_prio

        cdef cydriver.CUstream s
        with nogil:
            HANDLE_RETURN(cydriver.cuStreamCreateWithPriority(&s, flags, prio))
        self._handle = s
        self._nonblocking = int(nonblocking)
        self._priority = prio
        self._device_id = device_id if device_id is not None else self._device_id
        return self

    def __dealloc__(self):
        self.close()

    cpdef close(self):
        """Destroy the stream.

        Destroy the stream if we own it. Borrowed foreign stream
        objects will instead have their references released.
        Calling this on a built-in default stream is a no-op.

        """
        if self._builtin:
            # The built-in default streams are shared module-level singletons;
            # resetting their handle would silently turn every user of
            # default_stream() into a NULL-stream user.
            return
        if self._owner is None:
            if self._handle:
                with nogil:
                    HANDLE_RETURN(cydriver.cuStreamDestroy(self._handle))
        else:
            # Borrowed stream: drop our reference to the foreign owner only.
            self._owner = None
        self._handle = <cydriver.CUstream>(NULL)

    def __cuda_stream__(self) -> tuple[int, int]:
        """Return an instance of a __cuda_stream__ protocol."""
        return (0, <uintptr_t>(self._handle))

    def __hash__(self) -> int:
        # Ensure context is initialized for hash consistency
        if self._ctx_handle == CU_CONTEXT_INVALID:
            self._get_context()
        return hash((<uintptr_t>(self._ctx_handle), <uintptr_t>(self._handle)))

    def __eq__(self, other) -> bool:
        if not isinstance(other, Stream):
            return NotImplemented
        cdef Stream _other = <Stream>other
        # Fast path: compare handles first
        if <uintptr_t>(self._handle) != <uintptr_t>((_other)._handle):
            return False
        # Ensure contexts are initialized for both streams
        if self._ctx_handle == CU_CONTEXT_INVALID:
            self._get_context()
        if _other._ctx_handle == CU_CONTEXT_INVALID:
            _other._get_context()
        # Compare contexts as well
        return <uintptr_t>(self._ctx_handle) == <uintptr_t>((_other)._ctx_handle)

    @property
    def handle(self) -> cuda.bindings.driver.CUstream:
        """Return the underlying ``CUstream`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Stream.handle)``.
        """
        return driver.CUstream(<uintptr_t>(self._handle))

    @property
    def is_nonblocking(self) -> bool:
        """Return True if this is a nonblocking stream, otherwise False."""
        cdef unsigned int flags
        if self._nonblocking == -1:
            # Not cached yet (e.g. borrowed or built-in stream): ask the driver.
            with nogil:
                HANDLE_RETURN(cydriver.cuStreamGetFlags(self._handle, &flags))
            if flags & cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING:
                self._nonblocking = True
            else:
                self._nonblocking = False
        return bool(self._nonblocking)

    @property
    def priority(self) -> int:
        """Return the stream priority."""
        cdef int prio
        if self._priority == INT32_MIN:
            # Not cached yet (e.g. borrowed or built-in stream): ask the driver.
            with nogil:
                HANDLE_RETURN(cydriver.cuStreamGetPriority(self._handle, &prio))
            self._priority = prio
        return self._priority

    def sync(self):
        """Synchronize the stream."""
        with nogil:
            HANDLE_RETURN(cydriver.cuStreamSynchronize(self._handle))

    def record(self, event: Event = None, options: EventOptions = None) -> Event:
        """Record an event onto the stream.

        Creates an Event object (or reuses the given one) by
        recording on the stream.

        Parameters
        ----------
        event : :obj:`~_event.Event`, optional
            Optional event object to be reused for recording.
        options : :obj:`EventOptions`, optional
            Customizable dataclass for event creation options.
            Only used when ``event`` is not given.

        Returns
        -------
        :obj:`~_event.Event`
            The recorded event: the given one if supplied, otherwise a
            newly created event object.

        Raises
        ------
        TypeError
            If ``event`` is IPC-enabled.

        """
        # Create an Event object (or reusing the given one) by recording
        # on the stream. Event flags such as disabling timing, nonblocking,
        # and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions.
        if event is None:
            # Populates both _device_id and _ctx_handle for Event._init.
            self._get_device_and_context()
            event = Event._init(<int>(self._device_id), <uintptr_t>(self._ctx_handle), options, False)
        elif event.is_ipc_enabled:
            raise TypeError(
                "IPC-enabled events should not be re-recorded, instead create a "
                "new event by supplying options."
            )

        cdef cydriver.CUevent e = (<cyEvent?>(event))._handle
        with nogil:
            HANDLE_RETURN(cydriver.cuEventRecord(e, self._handle))
        return event

    def wait(self, event_or_stream: Union[Event, Stream]):
        """Wait for a CUDA event or a CUDA stream.

        Waiting for an event or a stream establishes a stream order.

        If a :obj:`~_stream.Stream` is provided, then wait until the stream's
        work is completed. This is done by recording a new :obj:`~_event.Event`
        on the stream and then waiting on it.

        Raises
        ------
        ValueError
            If ``event_or_stream`` is neither an Event, a Stream, nor an
            object supporting ``__cuda_stream__``.

        """
        cdef cydriver.CUevent event
        cdef cydriver.CUstream stream

        if isinstance(event_or_stream, Event):
            event = <cydriver.CUevent><uintptr_t>(event_or_stream.handle)
            with nogil:
                # TODO: support flags other than 0?
                HANDLE_RETURN(cydriver.cuStreamWaitEvent(self._handle, event, 0))
            return

        if isinstance(event_or_stream, Stream):
            stream = <cydriver.CUstream><uintptr_t>(event_or_stream.handle)
        else:
            try:
                # `s` keeps the foreign object alive for the rest of this call.
                s = Stream._init(obj=event_or_stream)
            except Exception as e:
                raise ValueError(
                    "only an Event, Stream, or object supporting __cuda_stream__ can be waited,"
                    f" got {type(event_or_stream)}"
                ) from e
            stream = <cydriver.CUstream><uintptr_t>(s.handle)

        with nogil:
            HANDLE_RETURN(cydriver.cuEventCreate(&event, cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
        try:
            with nogil:
                HANDLE_RETURN(cydriver.cuEventRecord(event, stream))
                # TODO: support flags other than 0?
                HANDLE_RETURN(cydriver.cuStreamWaitEvent(self._handle, event, 0))
        finally:
            # Always release the temporary event, even if record/wait failed.
            with nogil:
                HANDLE_RETURN(cydriver.cuEventDestroy(event))

    @property
    def device(self) -> Device:
        """Return the :obj:`~_device.Device` singleton associated with this stream.

        Note
        ----
        The current context on the device may differ from this
        stream's context. This case occurs when a different CUDA
        context is set current after a stream is created.

        """
        from cuda.core.experimental._device import Device  # avoid circular import
        self._get_device_and_context()
        return Device(<int>(self._device_id))

    cdef int _get_context(self) except?-1 nogil:
        # Populate self._ctx_handle once from cuStreamGetCtx; later calls reuse it.
        if self._ctx_handle == CU_CONTEXT_INVALID:
            HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle, &(self._ctx_handle)))
        return 0

    cdef int _get_device_and_context(self) except?-1:
        # Populate both self._ctx_handle and self._device_id.
        #
        # The context is resolved unconditionally (cheap: _get_context caches
        # it) because _init() may pre-set _device_id without touching
        # _ctx_handle, while callers such as record() read _ctx_handle after
        # this returns.
        cdef cydriver.CUcontext curr_ctx
        self._get_context()
        if self._device_id == cydriver.CU_DEVICE_INVALID:
            with nogil:
                # Get the current context
                HANDLE_RETURN(cydriver.cuCtxGetCurrent(&curr_ctx))
                # Get the stream's device (may require a context-switching dance)
                self._device_id = get_device_from_ctx(self._ctx_handle, curr_ctx)
        return 0

    @property
    def context(self) -> Context:
        """Return the :obj:`~_context.Context` associated with this stream."""
        self._get_device_and_context()
        return Context._from_ctx(<uintptr_t>(self._ctx_handle), <int>(self._device_id))

    @staticmethod
    def from_handle(handle: int) -> Stream:
        """Create a new :obj:`~_stream.Stream` object from a foreign stream handle.

        Uses a cudaStream_t pointer address represented as a Python int
        to create a new :obj:`~_stream.Stream` object.

        Note
        ----
        Stream lifetime is not managed, foreign object must remain
        alive while this stream is active.

        Parameters
        ----------
        handle : int
            Stream handle representing the address of a foreign
            stream object.

        Returns
        -------
        :obj:`~_stream.Stream`
            Newly created stream object.

        """

        # Minimal adapter so the raw address can go through the
        # __cuda_stream__ protocol path of _init().
        class _stream_holder:
            def __cuda_stream__(self):
                return (0, handle)

        return Stream._init(obj=_stream_holder())

    def create_graph_builder(self) -> GraphBuilder:
        """Create a new :obj:`~_graph.GraphBuilder` object.

        The new graph builder will be associated with this stream.

        Returns
        -------
        :obj:`~_graph.GraphBuilder`
            Newly created graph builder object.

        """
        return GraphBuilder._init(stream=self, is_stream_owner=False)
# Built-in default streams. Each one exists twice: a typed cdef global that
# is private to C-level code in this module, and a plain module attribute
# that exposes the very same object as public Python API.
cdef Stream C_LEGACY_DEFAULT_STREAM = Stream._legacy_default()
LEGACY_DEFAULT_STREAM = C_LEGACY_DEFAULT_STREAM

cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
PER_THREAD_DEFAULT_STREAM = C_PER_THREAD_DEFAULT_STREAM
cpdef Stream default_stream():
    """Return the default CUDA :obj:`~_stream.Stream`.

    The type of default stream returned depends on the environment
    variable CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM. Its value is
    parsed as a base-10 integer prefix (like C ``strtol``).

    If that integer is non-zero, returns the per-thread default stream.
    Otherwise (unset, ``0``, or no leading digits) returns the legacy
    default stream.

    """
    # TODO: flip the default
    cdef const char* use_ptds_raw = getenv("CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM")

    # strtol returns a C long; storing it in an int would truncate values
    # such as "4294967296" to 0 on LP64 platforms and wrongly pick legacy.
    cdef long use_ptds = 0
    if use_ptds_raw != NULL:
        use_ptds = strtol(use_ptds_raw, NULL, 10)

    # value is non-zero, including for weird stuff like 123foo
    if use_ptds:
        return C_PER_THREAD_DEFAULT_STREAM
    else:
        return C_LEGACY_DEFAULT_STREAM
cdef cydriver.CUstream _handle_from_stream_protocol(obj) except*:
    # Extract the raw CUstream from `obj`, which must be a Stream or implement
    # the __cuda_stream__ protocol (see IsStreamT).
    #
    # Raises:
    #   TypeError    -- obj has no __cuda_stream__ attribute.
    #   RuntimeError -- the protocol result is not a 2-element sequence whose
    #                   first element (the protocol version) is 0.
    # A non-callable __cuda_stream__ attribute is still accepted but emits a
    # DeprecationWarning.
    if isinstance(obj, Stream):
        return <cydriver.CUstream><uintptr_t>(obj.handle)

    try:
        cuda_stream_attr = obj.__cuda_stream__
    except AttributeError:
        raise TypeError(f"{type(obj)} object does not have a '__cuda_stream__' attribute") from None

    if callable(cuda_stream_attr):
        info = cuda_stream_attr()
    else:
        info = cuda_stream_attr
        # Intentionally no warnings.simplefilter() here: it would mutate the
        # process-wide filter list on every call and override filters the
        # application configured (e.g. -W error::DeprecationWarning).
        warnings.warn(
            "Implementing __cuda_stream__ as an attribute is deprecated; it must be implemented as a method",
            stacklevel=3,
            category=DeprecationWarning,
        )

    try:
        len_info = len(info)
    except TypeError as e:
        raise RuntimeError(f"obj.__cuda_stream__ must return a sequence with 2 elements, got {type(info)}") from e
    if len_info != 2:
        raise RuntimeError(f"obj.__cuda_stream__ must return a sequence with 2 elements, got {len_info} elements")
    if info[0] != 0:
        raise RuntimeError(
            f"The first element of the sequence returned by obj.__cuda_stream__ must be 0, got {repr(info[0])}"
        )
    return <cydriver.CUstream><uintptr_t>(info[1])
# Helper for API functions that accept either Stream or GraphBuilder. Performs
# needed checks and returns the relevant stream.
cdef Stream Stream_accept(arg, bint allow_stream_protocol=False):
    # Normalize `arg` to a Stream.
    #
    # - Stream: returned as-is.
    # - GraphBuilder: its `.stream` is returned.
    # - Anything else, when allow_stream_protocol is true: wrapped through the
    #   __cuda_stream__ protocol (deprecated, emits DeprecationWarning).
    # Otherwise raises TypeError.
    if isinstance(arg, Stream):
        return <Stream>(arg)
    elif isinstance(arg, GraphBuilder):
        return <Stream>(arg.stream)
    elif allow_stream_protocol:
        try:
            stream = Stream._init(arg)
        except Exception:
            # Not usable via the stream protocol; fall through to TypeError.
            # (Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are no longer swallowed.)
            pass
        else:
            # Stream(obj) cannot be used here: Stream.__init__ always raises.
            warnings.warn(
                "Passing foreign stream objects to this function via the "
                "stream protocol is deprecated. Convert the object explicitly "
                "using Device.create_stream(obj) instead.",
                stacklevel=2,
                category=DeprecationWarning,
            )
            return <Stream>(stream)
    raise TypeError(f"Stream or GraphBuilder expected, got {type(arg).__name__}")