Coverage for cuda / core / experimental / _stream.pyx: 88%

201 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-10 01:19 +0000

1# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4  

5from __future__ import annotations 

6  

7from libc.stdint cimport uintptr_t, INT32_MIN 

8from libc.stdlib cimport strtol, getenv 

9  

10from cuda.bindings cimport cydriver 

11  

12from cuda.core.experimental._event cimport Event as cyEvent 

13from cuda.core.experimental._utils.cuda_utils cimport ( 

14 check_or_create_options, 

15 CU_CONTEXT_INVALID, 

16 get_device_from_ctx, 

17 HANDLE_RETURN, 

18) 

19  

20import cython 

21import warnings 

22from dataclasses import dataclass 

23from typing import TYPE_CHECKING, Optional, Protocol, Union 

24  

25if TYPE_CHECKING: 

26 import cuda.bindings 

27 from cuda.core.experimental._device import Device 

28from cuda.core.experimental._context import Context 

29from cuda.core.experimental._event import Event, EventOptions 

30from cuda.core.experimental._graph import GraphBuilder 

31from cuda.core.experimental._utils.cuda_utils import ( 

32 driver, 

33) 

34  

35  

@dataclass
cdef class StreamOptions:
    """Options that customize how a :obj:`~_stream.Stream` is created.

    Attributes
    ----------
    nonblocking : bool, optional
        When true, the stream does not synchronize with the NULL stream.
        (Default to True)
    priority : int, optional
        Priority of the stream; a lower number represents a higher
        priority. (Default to lowest priority)

    """

    nonblocking : cython.bint = True
    priority: Optional[int] = None

52  

53  

class IsStreamT(Protocol):
    """Structural type of any object that can be interpreted as a CUDA stream."""

    def __cuda_stream__(self) -> tuple[int, int]:
        """
        Describe the CUDA stream that this object stands for.

        Any Python object that is meant to be interpreted as a CUDA stream can
        communicate that intent by implementing this protocol method. It returns
        a 2-tuple: the protocol version number (currently ``0``) followed by the
        address of ``cudaStream_t``. Both values should be Python `int`.
        """
        ...

63  

64  

cdef class Stream:
    """Represent a queue of GPU operations that are executed in a specific order.

    Applications use streams to control the order of execution for
    GPU work. Work within a single stream is executed sequentially,
    whereas work across multiple streams can be further controlled
    using stream priorities and :obj:`~_event.Event` management.

    Advanced users can utilize default streams to enforce complex
    implicit synchronization behaviors.

    Directly creating a :obj:`~_stream.Stream` is not supported due to ambiguity.
    New streams should instead be created through a :obj:`~_device.Device`
    object, or created directly from an existing handle
    using Stream.from_handle().
    """
    def __cinit__(self):
        # Every cached attribute starts at a sentinel value and is filled in
        # lazily by the accessor that first needs it.
        self._handle = <cydriver.CUstream>(NULL)
        self._owner = None
        self._builtin = False
        self._nonblocking = -1  # lazy init'd
        self._priority = INT32_MIN  # lazy init'd
        self._device_id = cydriver.CU_DEVICE_INVALID  # lazy init'd
        self._ctx_handle = CU_CONTEXT_INVALID  # lazy init'd

    def __init__(self, *args, **kwargs):
        # Public construction is rejected; the factory classmethods below go
        # through ``Stream.__new__`` directly and never reach this method.
        raise RuntimeError(
            "Stream objects cannot be instantiated directly. "
            "Please use Device APIs (create_stream) or other Stream APIs (from_handle)."
        )

    @classmethod
    def _legacy_default(cls):
        """Return a Stream wrapping the built-in ``CU_STREAM_LEGACY`` handle."""
        cdef Stream self = Stream.__new__(cls)
        self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_LEGACY)
        # Built-in streams are never destroyed by close().
        self._builtin = True
        return self

    @classmethod
    def _per_thread_default(cls):
        """Return a Stream wrapping the built-in ``CU_STREAM_PER_THREAD`` handle."""
        cdef Stream self = Stream.__new__(cls)
        self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_PER_THREAD)
        # Built-in streams are never destroyed by close().
        self._builtin = True
        return self

    @classmethod
    def _init(cls, obj: IsStreamT | None = None, options=None, device_id: int = None):
        """Create a Stream, either borrowing a foreign stream or creating a new one.

        Parameters
        ----------
        obj : object implementing ``__cuda_stream__``, optional
            Foreign stream object to borrow the handle from. The object is
            kept referenced as the owner of the handle.
        options : :obj:`StreamOptions`, optional
            Options for creating a new stream. Mutually exclusive with ``obj``.
        device_id : int, optional
            Device ordinal to cache on a newly created stream. It is not used
            when ``obj`` is given.

        Raises
        ------
        ValueError
            If both ``obj`` and ``options`` are given, or if the requested
            priority is outside the range reported by the driver.

        """
        cdef Stream self = Stream.__new__(cls)

        if obj is not None and options is not None:
            raise ValueError("obj and options cannot be both specified")
        if obj is not None:
            self._handle = _handle_from_stream_protocol(obj)
            # TODO: check if obj is created under the current context/device
            # Keep the foreign object referenced; close() then only drops the
            # reference instead of destroying the handle.
            self._owner = obj
            return self

        cdef StreamOptions opts = check_or_create_options(StreamOptions, options, "Stream options")
        nonblocking = opts.nonblocking
        priority = opts.priority

        flags = cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING if nonblocking else cydriver.CUstream_flags.CU_STREAM_DEFAULT
        # TODO: we might want to consider memoizing high/low per CUDA context and avoid this call
        # ``high`` receives the numerically largest value and ``low`` the
        # numerically smallest one, which is what the range check below relies on.
        cdef int high, low
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxGetStreamPriorityRange(&high, &low))
        cdef int prio
        if priority is not None:
            prio = priority
            if not (low <= prio <= high):
                raise ValueError(f"{priority=} is out of range {[low, high]}")
        else:
            # No priority requested: use the numerically largest value.
            prio = high

        cdef cydriver.CUstream s
        with nogil:
            HANDLE_RETURN(cydriver.cuStreamCreateWithPriority(&s, flags, prio))
        self._handle = s
        # The creation parameters are known here, so cache them and spare the
        # property accessors a driver query.
        self._nonblocking = int(nonblocking)
        self._priority = prio
        self._device_id = device_id if device_id is not None else self._device_id
        return self

    def __dealloc__(self):
        self.close()

    cpdef close(self):
        """Destroy the stream.

        Destroy the stream if we own it. Borrowed foreign stream
        object will instead have their references released.

        """
        if self._owner is None:
            # Owned stream: destroy it unless it is unset or a built-in handle.
            if self._handle and not self._builtin:
                with nogil:
                    HANDLE_RETURN(cydriver.cuStreamDestroy(self._handle))
        else:
            # Borrowed stream: only release the reference to the foreign owner.
            self._owner = None
        self._handle = <cydriver.CUstream>(NULL)

    def __cuda_stream__(self) -> tuple[int, int]:
        """Return an instance of a __cuda_stream__ protocol."""
        return (0, <uintptr_t>(self._handle))

    def __hash__(self) -> int:
        # Ensure context is initialized for hash consistency
        if self._ctx_handle == CU_CONTEXT_INVALID:
            self._get_context()
        return hash((<uintptr_t>(self._ctx_handle), <uintptr_t>(self._handle)))

    def __eq__(self, other) -> bool:
        if not isinstance(other, Stream):
            return NotImplemented
        cdef Stream _other = <Stream>other
        # Fast path: compare handles first
        if <uintptr_t>(self._handle) != <uintptr_t>((_other)._handle):
            return False
        # Ensure contexts are initialized for both streams
        if self._ctx_handle == CU_CONTEXT_INVALID:
            self._get_context()
        if _other._ctx_handle == CU_CONTEXT_INVALID:
            _other._get_context()
        # Compare contexts as well
        return <uintptr_t>(self._ctx_handle) == <uintptr_t>((_other)._ctx_handle)

    @property
    def handle(self) -> cuda.bindings.driver.CUstream:
        """Return the underlying ``CUstream`` object.

        .. caution::

            This handle is a Python object. To get the memory address of the underlying C
            handle, call ``int(Stream.handle)``.
        """
        return driver.CUstream(<uintptr_t>(self._handle))

    @property
    def is_nonblocking(self) -> bool:
        """Return True if this is a nonblocking stream, otherwise False."""
        cdef unsigned int flags
        # -1 means "not queried yet"; ask the driver once and cache the answer.
        if self._nonblocking == -1:
            with nogil:
                HANDLE_RETURN(cydriver.cuStreamGetFlags(self._handle, &flags))
            if flags & cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING:
                self._nonblocking = True
            else:
                self._nonblocking = False
        return bool(self._nonblocking)

    @property
    def priority(self) -> int:
        """Return the stream priority."""
        cdef int prio
        # INT32_MIN means "not queried yet"; ask the driver once and cache the answer.
        if self._priority == INT32_MIN:
            with nogil:
                HANDLE_RETURN(cydriver.cuStreamGetPriority(self._handle, &prio))
            self._priority = prio
        return self._priority

    def sync(self):
        """Synchronize the stream."""
        with nogil:
            HANDLE_RETURN(cydriver.cuStreamSynchronize(self._handle))

    def record(self, event: Event = None, options: EventOptions = None) -> Event:
        """Record an event onto the stream.

        Creates an Event object (or reuses the given one) by
        recording on the stream.

        Parameters
        ----------
        event : :obj:`~_event.Event`, optional
            Optional event object to be reused for recording.
        options : :obj:`EventOptions`, optional
            Customizable dataclass for event creation options. Only used
            when a new event is created, i.e. when ``event`` is not given.

        Returns
        -------
        :obj:`~_event.Event`
            The recorded event: the one passed in, or a newly created
            event object when none was given.

        Raises
        ------
        TypeError
            If ``event`` is IPC-enabled.

        """
        # Create an Event object (or reusing the given one) by recording
        # on the stream. Event flags such as disabling timing, nonblocking,
        # and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions.
        if event is None:
            # A new event needs both the device and the context of this
            # stream. _get_device_and_context() returns early when the device
            # id is already cached (e.g. supplied at creation time), which
            # would leave the context handle unset, so populate the context
            # explicitly first, as the ``context`` property does.
            self._get_context()
            self._get_device_and_context()
            event = Event._init(<int>(self._device_id), <uintptr_t>(self._ctx_handle), options, False)
        elif event.is_ipc_enabled:
            raise TypeError(
                "IPC-enabled events should not be re-recorded, instead create a "
                "new event by supplying options."
            )

        cdef cydriver.CUevent e = (<cyEvent?>(event))._handle
        with nogil:
            HANDLE_RETURN(cydriver.cuEventRecord(e, self._handle))
        return event

    def wait(self, event_or_stream: Union[Event, Stream]):
        """Wait for a CUDA event or a CUDA stream.

        Waiting for an event or a stream establishes a stream order.

        If a :obj:`~_stream.Stream` is provided, then wait until the stream's
        work is completed. This is done by recording a new :obj:`~_event.Event`
        on the stream and then waiting on it.

        Parameters
        ----------
        event_or_stream : :obj:`~_event.Event`, :obj:`~_stream.Stream`, or object implementing ``__cuda_stream__``
            The event or stream that this stream should wait for.

        Raises
        ------
        ValueError
            If the argument is not an Event or a Stream and cannot be
            converted to a stream through the ``__cuda_stream__`` protocol.

        """
        cdef cydriver.CUevent event
        cdef cydriver.CUstream stream

        if isinstance(event_or_stream, Event):
            event = <cydriver.CUevent><uintptr_t>(event_or_stream.handle)
            with nogil:
                # TODO: support flags other than 0?
                HANDLE_RETURN(cydriver.cuStreamWaitEvent(self._handle, event, 0))
        else:
            if isinstance(event_or_stream, Stream):
                stream = <cydriver.CUstream><uintptr_t>(event_or_stream.handle)
            else:
                try:
                    s = Stream._init(obj=event_or_stream)
                except Exception as e:
                    raise ValueError(
                        "only an Event, Stream, or object supporting __cuda_stream__ can be waited,"
                        f" got {type(event_or_stream)}"
                    ) from e
                stream = <cydriver.CUstream><uintptr_t>(s.handle)
            # Record a temporary event on the other stream and wait on it.
            with nogil:
                HANDLE_RETURN(cydriver.cuEventCreate(&event, cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
            try:
                with nogil:
                    HANDLE_RETURN(cydriver.cuEventRecord(event, stream))
                    # TODO: support flags other than 0?
                    HANDLE_RETURN(cydriver.cuStreamWaitEvent(self._handle, event, 0))
            finally:
                # Destroy the temporary event even when recording or waiting
                # failed, so that an error does not leak it.
                with nogil:
                    HANDLE_RETURN(cydriver.cuEventDestroy(event))

    @property
    def device(self) -> Device:
        """Return the :obj:`~_device.Device` singleton associated with this stream.

        Note
        ----
        The current context on the device may differ from this
        stream's context. This case occurs when a different CUDA
        context is set current after a stream is created.

        """
        from cuda.core.experimental._device import Device  # avoid circular import
        self._get_device_and_context()
        return Device(<int>(self._device_id))

    cdef int _get_context(self) except?-1 nogil:
        # Populate self._ctx_handle from the driver if it is not cached yet.
        if self._ctx_handle == CU_CONTEXT_INVALID:
            HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle, &(self._ctx_handle)))
        return 0

    cdef int _get_device_and_context(self) except?-1:
        # Populate self._device_id (and, on that path, self._ctx_handle).
        # NOTE: when the device id is already cached this returns without
        # touching the context handle; callers that need the context must
        # call _get_context() themselves.
        cdef cydriver.CUcontext curr_ctx
        if self._device_id == cydriver.CU_DEVICE_INVALID:
            with nogil:
                # Get the current context
                HANDLE_RETURN(cydriver.cuCtxGetCurrent(&curr_ctx))
                # Get the stream's context (self.ctx_handle is populated)
                self._get_context()
                # Get the stream's device (may require a context-switching dance)
                self._device_id = get_device_from_ctx(self._ctx_handle, curr_ctx)
        return 0

    @property
    def context(self) -> Context:
        """Return the :obj:`~_context.Context` associated with this stream."""
        self._get_context()
        self._get_device_and_context()
        return Context._from_ctx(<uintptr_t>(self._ctx_handle), <int>(self._device_id))

    @staticmethod
    def from_handle(handle: int) -> Stream:
        """Create a new :obj:`~_stream.Stream` object from a foreign stream handle.

        Uses a cudaStream_t pointer address represented as a Python int
        to create a new :obj:`~_stream.Stream` object.

        Note
        ----
        Stream lifetime is not managed, foreign object must remain
        alive while this stream is active.

        Parameters
        ----------
        handle : int
            Stream handle representing the address of a foreign
            stream object.

        Returns
        -------
        :obj:`~_stream.Stream`
            Newly created stream object.

        """

        # Wrap the raw address in a minimal object that speaks the
        # __cuda_stream__ protocol so the regular _init() path can be reused.
        class _stream_holder:
            def __cuda_stream__(self):
                return (0, handle)

        return Stream._init(obj=_stream_holder())

    def create_graph_builder(self) -> GraphBuilder:
        """Create a new :obj:`~_graph.GraphBuilder` object.

        The new graph builder will be associated with this stream.

        Returns
        -------
        :obj:`~_graph.GraphBuilder`
            Newly created graph builder object.

        """
        return GraphBuilder._init(stream=self, is_stream_owner=False)

385  

386  

# Module-level singletons for the two built-in default streams. Each one is
# kept both as a C-typed object (c-only, not public) and under a plain Python
# name (public) that refers to the very same object.

# legacy default stream
cdef Stream C_LEGACY_DEFAULT_STREAM = Stream._legacy_default()
LEGACY_DEFAULT_STREAM = C_LEGACY_DEFAULT_STREAM

# per-thread default stream
cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
PER_THREAD_DEFAULT_STREAM = C_PER_THREAD_DEFAULT_STREAM

394  

395  

cpdef Stream default_stream():
    """Return the default CUDA :obj:`~_stream.Stream`.

    The type of default stream returned depends on the environment
    variable CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM.

    If it is set to a value that parses (base 10) as a non-zero integer,
    returns the per-thread default stream. Otherwise (unset, or parsing
    to zero) returns the legacy default stream.

    """
    # TODO: flip the default
    cdef const char* use_ptds_raw = getenv("CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM")

    # Test for non-zero while the parsed value is still a C ``long``.
    # Narrowing it to ``int`` first could truncate a large non-zero value
    # (e.g. 4294967296 where ``long`` is 64-bit) to 0.
    cdef bint use_ptds = False
    if use_ptds_raw != NULL:
        use_ptds = strtol(use_ptds_raw, NULL, 10) != 0

    # value is non-zero, including for weird stuff like 123foo
    if use_ptds:
        return C_PER_THREAD_DEFAULT_STREAM
    else:
        return C_LEGACY_DEFAULT_STREAM

418  

419  

cdef cydriver.CUstream _handle_from_stream_protocol(obj) except*:
    """Extract the raw ``CUstream`` handle from ``obj``.

    ``obj`` is either a :obj:`Stream` or any object exposing
    ``__cuda_stream__``, which must yield a 2-element sequence
    ``(0, address)``.

    Raises ``TypeError`` when ``obj`` has no ``__cuda_stream__`` attribute and
    ``RuntimeError`` when the value it provides is malformed.
    """
    # Our own Stream type: take the handle straight from the object.
    if isinstance(obj, Stream):
        return <cydriver.CUstream><uintptr_t>(obj.handle)

    try:
        protocol_attr = obj.__cuda_stream__
    except AttributeError:
        raise TypeError(f"{type(obj)} object does not have a '__cuda_stream__' attribute") from None

    if not callable(protocol_attr):
        # Deprecated form: the attribute holds the sequence itself.
        info = protocol_attr
        # NOTE(review): this alters the process-wide warning filters on every
        # call that takes this branch - confirm that this is intended.
        warnings.simplefilter("once", DeprecationWarning)
        warnings.warn(
            "Implementing __cuda_stream__ as an attribute is deprecated; it must be implemented as a method",
            stacklevel=3,
            category=DeprecationWarning,
        )
    else:
        # Regular form: the attribute is a method returning the sequence.
        info = protocol_attr()

    # Validate the shape of the result before trusting its contents.
    try:
        n_items = len(info)
    except TypeError as e:
        raise RuntimeError(f"obj.__cuda_stream__ must return a sequence with 2 elements, got {type(info)}") from e
    if n_items != 2:
        raise RuntimeError(f"obj.__cuda_stream__ must return a sequence with 2 elements, got {n_items} elements")

    # Only protocol version 0 is accepted.
    if info[0] != 0:
        raise RuntimeError(
            f"The first element of the sequence returned by obj.__cuda_stream__ must be 0, got {info[0]!r}"
        )

    # The second element is the stream address as a Python int.
    return <cydriver.CUstream><uintptr_t>(info[1])

451  

# Helper for API functions that accept either Stream or GraphBuilder. Performs
# needed checks and returns the relevant stream.
cdef Stream Stream_accept(arg, bint allow_stream_protocol=False):
    """Return the :obj:`Stream` that ``arg`` refers to.

    A :obj:`Stream` is returned as is and a :obj:`GraphBuilder` yields its
    ``stream``. When ``allow_stream_protocol`` is true, any other object is
    additionally tried as a foreign stream via ``Stream._init``; a successful
    conversion emits a ``DeprecationWarning``.

    Raises ``TypeError`` if ``arg`` cannot be interpreted as a stream.
    """
    if isinstance(arg, Stream):
        return <Stream>(arg)
    elif isinstance(arg, GraphBuilder):
        return <Stream>(arg.stream)
    elif allow_stream_protocol:
        try:
            stream = Stream._init(arg)
        except Exception:
            # Best effort: an object that cannot be converted is reported by
            # the TypeError below. Only ``Exception`` is caught so that
            # KeyboardInterrupt and SystemExit are not silently discarded.
            pass
        else:
            # NOTE(review): the message recommends ``Stream(obj)``, but
            # ``Stream.__init__`` raises RuntimeError - confirm the intended
            # replacement API before changing the wording.
            warnings.warn(
                "Passing foreign stream objects to this function via the "
                "stream protocol is deprecated. Convert the object explicitly "
                "using Stream(obj) instead.",
                stacklevel=2,
                category=DeprecationWarning,
            )
            return <Stream>(stream)
    raise TypeError(f"Stream or GraphBuilder expected, got {type(arg).__name__}")