include/cuda/experimental/__stream/stream_ref.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__STREAM_STREAM_REF
#define _CUDAX__STREAM_STREAM_REF

#include <cuda/std/detail/__config>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/stream_ref>

#include <cuda/experimental/__device/all_devices.cuh>
#include <cuda/experimental/__device/logical_device.cuh>
#include <cuda/experimental/__event/timed_event.cuh>
#include <cuda/experimental/__execution/fwd.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

#include <cuda_runtime_api.h>

#include <cuda/std/__cccl/prologue.h>

namespace cuda::experimental
{

namespace __detail
{
// 0 is a valid stream handle in CUDA, so we need some other representation for an invalid stream.
// This can't be constexpr because cudaStream_t is a pointer type and the reinterpret_cast is not
// allowed in a constant expression.
static const ::cudaStream_t __invalid_stream = reinterpret_cast<cudaStream_t>(~0ULL);
} // namespace __detail

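// A non-owning reference to a CUDA stream. It extends ::cuda::stream_ref with event recording,
// cross-stream waiting, device queries, and the queries used by the experimental execution machinery.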
struct stream_ref : ::cuda::stream_ref
{
  stream_ref() = delete;

  _CCCL_HOST_API constexpr stream_ref(value_type __stream) noexcept
      : ::cuda::stream_ref{__stream}
  {}

  _CCCL_HOST_API constexpr stream_ref(const ::cuda::stream_ref& __other) noexcept
      : ::cuda::stream_ref(__other)
  {}

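  // Disallow construction from a literal 0 or nullptr: both would otherwise convert to
  // cudaStream_t and silently refer to the default stream.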
  stream_ref(int) = delete;

  stream_ref(_CUDA_VSTD::nullptr_t) = delete;

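  // Returns true if all work previously submitted to this stream has completed; does not block.
  // Throws cuda::cuda_error if the underlying stream query fails.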
  [[nodiscard]] bool is_done() const
  {
    const auto __result = __detail::driver::streamQuery(__stream);
    switch (__result)
    {
      case ::cudaErrorNotReady:
        return false;
      case ::cudaSuccess:
        return true;
      default:
        ::cuda::__throw_cuda_error(__result, "Failed to query stream.");
    }
  }

  [[deprecated("Use is_done() instead.")]] [[nodiscard]] bool ready() const
  {
    return is_done();
  }

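  // Returns the priority with which this stream was created.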
  [[nodiscard]] _CCCL_HOST_API int priority() const
  {
    return __detail::driver::streamGetPriority(__stream);
  }

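  // Records a new event on this stream and returns it.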
  [[nodiscard]] _CCCL_HOST_API event record_event(event::flags __flags = event::flags::none) const
  {
    return event(*this, __flags);
  }

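  // Records a new timed event (usable for elapsed-time measurements) on this stream and returns it.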
  [[nodiscard]] _CCCL_HOST_API timed_event record_timed_event(event::flags __flags = event::flags::none) const
  {
    return timed_event(*this, __flags);
  }

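  // Blocks the calling thread until all work previously submitted to this stream has completed.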
  _CCCL_HOST_API void sync() const
  {
    __detail::driver::streamSynchronize(__stream);
  }

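  // Makes all future work submitted to this stream wait until __ev has completed.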
  _CCCL_HOST_API void wait(event_ref __ev) const
  {
    _CCCL_ASSERT(__ev.get() != nullptr, "cuda::experimental::stream_ref::wait invalid event passed");
    // Need to use the driver API here; cudaStreamWaitEvent would push device 0 onto the context stack if it was empty
    __detail::driver::streamWaitEvent(get(), __ev.get());
  }

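  // Returns a sender that schedules work on this stream; the definition is provided by the
  // execution headers.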
  _CCCL_HOST_API auto schedule() const noexcept;

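  // Makes all future work submitted to this stream wait for all work currently enqueued on __other.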
  _CCCL_HOST_API void wait(stream_ref __other) const
  {
    // TODO: consider an optimization to avoid creating an event every time, e.g. one persistent
    // event or one event per stream
    _CCCL_ASSERT(__stream != __detail::__invalid_stream, "cuda::experimental::stream_ref::wait invalid stream passed");
    if (*this != __other)
    {
      event __tmp(__other);
      wait(__tmp);
    }
  }

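  // Returns the logical device (regular device or green context) under which this stream was created.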
  _CCCL_HOST_API ::cuda::experimental::logical_device logical_device() const
  {
    CUcontext __stream_ctx;
    ::cuda::experimental::logical_device::kinds __ctx_kind = ::cuda::experimental::logical_device::kinds::device;
#if CUDART_VERSION >= 12050
    if (__detail::driver::getVersion() >= 12050)
    {
      auto __ctx = __detail::driver::streamGetCtx_v2(__stream);
      if (__ctx.__ctx_kind == __detail::driver::__ctx_from_stream::__kind::__green)
      {
        __stream_ctx = __detail::driver::ctxFromGreenCtx(__ctx.__ctx_ptr.__green);
        __ctx_kind   = ::cuda::experimental::logical_device::kinds::green_context;
      }
      else
      {
        __stream_ctx = __ctx.__ctx_ptr.__device;
        __ctx_kind   = ::cuda::experimental::logical_device::kinds::device;
      }
    }
    else
#endif // CUDART_VERSION >= 12050
    {
      __stream_ctx = __detail::driver::streamGetCtx(__stream);
      __ctx_kind   = ::cuda::experimental::logical_device::kinds::device;
    }
    // Because the stream can come from from_native_handle, we can't just loop over devices comparing
    // contexts; lower to CUDART for this instead
    __ensure_current_device __setter(__stream_ctx);
    int __id;
    _CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &__id);
    return __logical_device_access::make_logical_device(__id, __stream_ctx, __ctx_kind);
  }

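  // Returns the device under which this stream was created. For a stream created under a green
  // context, this is the device the green context was created on.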
  _CCCL_HOST_API device_ref device() const
  {
    return logical_device().underlying_device();
  }

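  // Environment query: the stream associated with this object is the object itself.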
  [[nodiscard]] _CCCL_API constexpr auto query(const get_stream_t&) const noexcept -> stream_ref
  {
    return *this;
  }

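  // Environment query: work scheduled on a CUDA stream provides only the weakly parallel
  // forward progress guarantee.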
  [[nodiscard]] _CCCL_API static constexpr auto query(const execution::get_forward_progress_guarantee_t&) noexcept
    -> execution::forward_progress_guarantee
  {
    return execution::forward_progress_guarantee::weakly_parallel;
  }

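  // Environment query: algorithm customizations for this stream are dispatched through
  // execution::stream_domain.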
  [[nodiscard]] _CCCL_API static constexpr auto query(const execution::get_domain_t&) noexcept
    -> execution::stream_domain;
};
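
// Illustrative usage sketch, not part of the header itself. It assumes an existing cudaStream_t
// `raw_stream` (e.g. obtained from cudaStreamCreate) and a second stream_ref `other`:
//
//   cuda::experimental::stream_ref ref{raw_stream};
//   auto ev = ref.record_event(); // capture the work enqueued on `ref` so far
//   other.wait(ev);               // order `other` after that work
//   ref.sync();                   // block until everything submitted to `ref` is done
//   assert(ref.is_done());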

} // namespace cuda::experimental

#include <cuda/std/__cccl/epilogue.h>

#endif // _CUDAX__STREAM_STREAM_REF