include/cuda/experimental/__stream/stream_ref.cuh
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#ifndef _CUDAX__STREAM_STREAM_REF
#define _CUDAX__STREAM_STREAM_REF

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/stream_ref>

#include <cuda/experimental/__device/all_devices.cuh>
#include <cuda/experimental/__device/logical_device.cuh>
#include <cuda/experimental/__event/timed_event.cuh>
#include <cuda/experimental/__execution/fwd.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

#include <cuda_runtime_api.h>

#include <cuda/std/__cccl/prologue.h>

namespace cuda::experimental
{
namespace __detail
{
// 0 is a valid stream in CUDA, so we need some other invalid stream representation
// Can't make it constexpr, because cudaStream_t is a pointer type
static const ::cudaStream_t __invalid_stream = reinterpret_cast<cudaStream_t>(~0ULL);
} // namespace __detail
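
//! @brief A non-owning reference to a CUDA stream. Extends ::cuda::stream_ref with
//! event recording, synchronization helpers, and device queries.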
struct stream_ref : ::cuda::stream_ref
{
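//! A stream_ref must always refer to a stream, so default construction is disallowed.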
stream_ref() = delete;
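
//! @brief Wraps the given cudaStream_t handle without taking ownership of the stream.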
_CCCL_HOST_API constexpr stream_ref(value_type __stream) noexcept
: ::cuda::stream_ref{__stream}
{}
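
//! @brief Converts from a ::cuda::stream_ref, referring to the same underlying stream.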
_CCCL_HOST_API constexpr stream_ref(const ::cuda::stream_ref& __other) noexcept
: ::cuda::stream_ref(__other)
{}
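
//! Construction from `0` or `nullptr` is disallowed to avoid accidentally referring to the default stream.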
stream_ref(int) = delete;
stream_ref(_CUDA_VSTD::nullptr_t) = delete;
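
//! @brief Queries whether all operations previously submitted to this stream have completed.
//! @throws cuda_error if the stream query fails.
//! @return true if all submitted work has completed, false otherwise.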
[[nodiscard]] _CCCL_HOST_API bool is_done() const
{
const auto __result = __detail::driver::streamQuery(__stream);
switch (__result)
{
case ::cudaErrorNotReady:
return false;
case ::cudaSuccess:
return true;
default:
::cuda::__throw_cuda_error(__result, "Failed to query stream.");
}
}
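
//! @brief Deprecated alias for is_done().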
[[deprecated("Use is_done() instead.")]] [[nodiscard]] bool ready() const
{
return is_done();
}
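
//! @brief Returns the priority of the referenced stream.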
[[nodiscard]] _CCCL_HOST_API int priority() const
{
return __detail::driver::streamGetPriority(__stream);
}
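
//! @brief Creates a new event with the given flags and records it into this stream.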
[[nodiscard]] _CCCL_HOST_API event record_event(event::flags __flags = event::flags::none) const
{
return event(*this, __flags);
}
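
//! @brief Creates a new timed event with the given flags and records it into this stream.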
[[nodiscard]] _CCCL_HOST_API timed_event record_timed_event(event::flags __flags = event::flags::none) const
{
return timed_event(*this, __flags);
}
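
//! @brief Blocks the calling thread until all work previously submitted to this stream has completed.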
_CCCL_HOST_API void sync() const
{
__detail::driver::streamSynchronize(__stream);
}
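
//! @brief Makes all future work submitted to this stream wait for the given event.
//! @param __ev the event to wait for; must not be null.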
_CCCL_HOST_API void wait(event_ref __ev) const
{
_CCCL_ASSERT(__ev.get() != nullptr, "cuda::experimental::stream_ref::wait invalid event passed");
// Need to use the driver API here: cudaStreamWaitEvent would push device 0 onto the context stack if it was empty
__detail::driver::streamWaitEvent(get(), __ev.get());
}
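
//! @brief Returns a sender that schedules work on this stream; only declared here, the
//! definition lives with the execution machinery.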
_CCCL_HOST_API auto schedule() const noexcept;
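
//! @brief Makes all future work submitted to this stream wait for all work currently queued on __other.
//! Implemented by recording a temporary event on __other and waiting on it.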
_CCCL_HOST_API void wait(stream_ref __other) const
{
// TODO: consider an optimization that avoids creating an event on every call, e.g. one persistent
// event or one event per stream
_CCCL_ASSERT(__stream != __detail::__invalid_stream, "cuda::experimental::stream_ref::wait invalid stream passed");
if (*this != __other)
{
event __tmp(__other);
wait(__tmp);
}
}
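
//! @brief Returns the logical device under which this stream was created.
//! For a stream created under a green context, the result refers to that green context.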
_CCCL_HOST_API ::cuda::experimental::logical_device logical_device() const
{
CUcontext __stream_ctx;
::cuda::experimental::logical_device::kinds __ctx_kind = ::cuda::experimental::logical_device::kinds::device;
#if CUDART_VERSION >= 12050
if (__detail::driver::getVersion() >= 12050)
{
auto __ctx = __detail::driver::streamGetCtx_v2(__stream);
if (__ctx.__ctx_kind == __detail::driver::__ctx_from_stream::__kind::__green)
{
__stream_ctx = __detail::driver::ctxFromGreenCtx(__ctx.__ctx_ptr.__green);
__ctx_kind = ::cuda::experimental::logical_device::kinds::green_context;
}
else
{
__stream_ctx = __ctx.__ctx_ptr.__device;
__ctx_kind = ::cuda::experimental::logical_device::kinds::device;
}
}
else
#endif // CUDART_VERSION >= 12050
{
__stream_ctx = __detail::driver::streamGetCtx(__stream);
__ctx_kind = ::cuda::experimental::logical_device::kinds::device;
}
// Because the stream can come from from_native_handle(), we can't just iterate over devices comparing
// contexts; instead, switch to the stream's context and ask the runtime which device is current
__ensure_current_device __setter(__stream_ctx);
int __id;
_CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &__id);
return __logical_device_access::make_logical_device(__id, __stream_ctx, __ctx_kind);
}
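
//! @brief Returns the device under which this stream was created.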
_CCCL_HOST_API device_ref device() const
{
return logical_device().underlying_device();
}
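
//! @brief Execution environment query: returns this stream.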
[[nodiscard]] _CCCL_API constexpr auto query(const get_stream_t&) const noexcept -> stream_ref
{
return *this;
}
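
//! @brief Work submitted to a CUDA stream only provides the weakly parallel forward progress guarantee.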
[[nodiscard]] _CCCL_API static constexpr auto query(const execution::get_forward_progress_guarantee_t&) noexcept
-> execution::forward_progress_guarantee
{
return execution::forward_progress_guarantee::weakly_parallel;
}
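
//! @brief Associates this stream with the stream execution domain (declaration only).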
[[nodiscard]] _CCCL_API static constexpr auto query(const execution::get_domain_t&) noexcept
-> execution::stream_domain;
};
} // namespace cuda::experimental

#include <cuda/std/__cccl/epilogue.h>

#endif // _CUDAX__STREAM_STREAM_REF