include/cuda/experimental/__stream/stream_ref.cuh

File members: include/cuda/experimental/__stream/stream_ref.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__STREAM_STREAM_REF
#define _CUDAX__STREAM_STREAM_REF

#include <cuda/std/detail/__config>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda_runtime_api.h>

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/stream_ref>

#include <cuda/experimental/__detail/config.cuh>
#include <cuda/experimental/__device/all_devices.cuh>
#include <cuda/experimental/__device/logical_device.cuh>
#include <cuda/experimental/__event/timed_event.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

namespace cuda::experimental
{

namespace detail
{
//! Sentinel handle used to represent "no valid stream".
//!
//! The null handle (0) already denotes a usable stream in CUDA, so it cannot
//! double as an "invalid" marker; the all-bits-set pattern is used instead.
//! This is `const` rather than `constexpr` because ::cudaStream_t is a pointer
//! type and building it from an integer needs a reinterpret_cast, which is not
//! permitted in a constant expression.
static ::cudaStream_t const __invalid_stream = reinterpret_cast<::cudaStream_t>(~0ULL);
} // namespace detail

//! @brief Non-owning reference to a CUDA stream.
//!
//! Extends ::cuda::stream_ref with event recording, stream-to-stream
//! dependencies and queries for the (logical) device the stream belongs to.
//! Default construction and construction from `0` / `nullptr` are deleted, so
//! every instance is built from an explicit stream handle or an existing
//! ::cuda::stream_ref.
struct stream_ref : ::cuda::stream_ref
{
  stream_ref() = delete;

  //! @brief Wrap an existing native stream handle.
  //! @param __stream The stream handle to refer to.
  _CUDAX_HOST_API constexpr stream_ref(value_type __stream) noexcept
      : ::cuda::stream_ref{__stream}
  {}

  //! @brief Convert from a ::cuda::stream_ref referring to the same stream.
  _CUDAX_HOST_API constexpr stream_ref(const ::cuda::stream_ref& __other) noexcept
      : ::cuda::stream_ref(__other)
  {}

  // Deleted so that a literal `0` or `nullptr` cannot silently become a stream_ref.
  stream_ref(int) = delete;

  stream_ref(_CUDA_VSTD::nullptr_t) = delete;

  //! @brief Create an event from this stream.
  //! @param __flags Flags forwarded to the event constructor.
  //! @return An `event` constructed as `event(*this, __flags)`.
  _CCCL_NODISCARD _CUDAX_HOST_API event record_event(event::flags __flags = event::flags::none) const
  {
    return event(*this, __flags);
  }

  //! @brief Create a timed event from this stream.
  //! @param __flags Flags forwarded to the timed_event constructor.
  //! @return A `timed_event` constructed as `timed_event(*this, __flags)`.
  _CCCL_NODISCARD _CUDAX_HOST_API timed_event record_timed_event(event::flags __flags = event::flags::none) const
  {
    return timed_event(*this, __flags);
  }

  //! @brief Host-side wait for this stream; forwards to ::cuda::stream_ref::wait().
  _CUDAX_TRIVIAL_HOST_API void wait() const
  {
    this->::cuda::stream_ref::wait();
  }

  //! @brief Make subsequent work submitted to this stream wait for @p __ev.
  //! @param __ev The event to wait on; must refer to a valid (non-null) event.
  _CUDAX_HOST_API void wait(event_ref __ev) const
  {
    _CCCL_ASSERT(__ev.get() != nullptr, "cuda::experimental::stream_ref::wait invalid event passed");
    // Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty
    detail::driver::streamWaitEvent(get(), __ev.get());
  }

  //! @brief Make subsequent work submitted to this stream wait for the work
  //! currently submitted to @p __other.
  //!
  //! Implemented by creating a temporary event from @p __other and waiting on
  //! it. Waiting on the same stream is a no-op.
  //! @param __other The stream to wait for; must not be the invalid-stream sentinel.
  _CUDAX_HOST_API void wait(stream_ref __other) const
  {
    _CCCL_ASSERT(__stream != detail::__invalid_stream,
                 "cuda::experimental::stream_ref::wait called on an invalid stream");
    // Validate the argument itself (previously only *this was checked, despite the message).
    _CCCL_ASSERT(__other.get() != detail::__invalid_stream,
                 "cuda::experimental::stream_ref::wait invalid stream passed");
    if (*this != __other)
    {
      // TODO consider an optimization to not create an event every time and instead have one persistent event or one
      // per stream
      event __tmp(__other);
      wait(__tmp);
    }
  }

  //! @brief Get the logical device (device or green context) this stream was created under.
  //!
  //! When both the runtime headers and the installed driver are at least 12.5,
  //! the driver's `_v2` stream-context query is used, so green contexts are
  //! reported as `kinds::green_context`. Otherwise the kind is `kinds::device`.
  //! The return type is fully qualified because this member function shares
  //! its name with the class type.
  _CUDAX_HOST_API ::cuda::experimental::logical_device logical_device() const
  {
    using __kinds = ::cuda::experimental::logical_device::kinds;

    CUcontext __stream_ctx = nullptr;
    __kinds __ctx_kind     = __kinds::device;
#if CUDART_VERSION >= 12050
    if (detail::driver::getVersion() >= 12050)
    {
      auto __ctx = detail::driver::streamGetCtx_v2(__stream);
      if (__ctx.__ctx_kind == detail::driver::__ctx_from_stream::__kind::__green)
      {
        // Convert the green context into a regular context handle, but remember its kind.
        __stream_ctx = detail::driver::ctxFromGreenCtx(__ctx.__ctx_ptr.__green);
        __ctx_kind   = __kinds::green_context;
      }
      else
      {
        __stream_ctx = __ctx.__ctx_ptr.__device;
      }
    }
    else
#endif // CUDART_VERSION >= 12050
    {
      __stream_ctx = detail::driver::streamGetCtx(__stream);
    }
    // Because the stream can come from_native_handle, we can't just loop over devices comparing contexts,
    // lower to CUDART for this instead
    __ensure_current_device __setter(__stream_ctx);
    int __id;
    _CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &__id);
    return __logical_device_access::make_logical_device(__id, __stream_ctx, __ctx_kind);
  }

  //! @brief Get the physical device underlying this stream's logical device.
  _CUDAX_HOST_API device_ref device() const
  {
    return logical_device().get_underlying_device();
  }
};

} // namespace cuda::experimental

#endif // _CUDAX__STREAM_STREAM_REF