include/cuda/experimental/__event/event_ref.cuh

File members: include/cuda/experimental/__event/event_ref.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX_EVENT_REF_DETAIL_H
#define _CUDAX_EVENT_REF_DETAIL_H

#include <cuda_runtime_api.h>
// cuda_runtime_api needs to come first

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/std/cassert>
#include <cuda/std/cstddef>
#include <cuda/std/utility>
#include <cuda/stream_ref>

#include <cuda/experimental/__utility/driver_api.cuh>

namespace cuda::experimental
{
class event;
class timed_event;

//! @brief A lightweight, non-owning reference to a `cudaEvent_t` handle.
//!
//! `event_ref` stores the raw handle by value and defines no destructor, so it never
//! destroys the event it refers to. The caller is responsible for keeping the event
//! alive for as long as it is used through this reference.
//! `event` and `timed_event` are friends so they can access the stored handle directly.
class event_ref
{
private:
  friend class event;
  friend class timed_event;

  ::cudaEvent_t __event_{};

public:
  using value_type = ::cudaEvent_t;

  //! @brief Construct an `event_ref` referring to `__evnt`.
  //!
  //! @param __evnt The event handle to refer to. It is stored as-is. No validation is
  //!        done here; a null handle is only diagnosed by the assertions in `record`,
  //!        `wait` and `is_done`.
  constexpr event_ref(::cudaEvent_t __evnt) noexcept
      : __event_(__evnt)
  {}

  //! Construction from the literal `0` is disallowed to catch accidental null handles.
  event_ref(int) = delete;

  //! Construction from `nullptr` is disallowed to catch accidental null handles.
  event_ref(_CUDA_VSTD::nullptr_t) = delete;

  //! @brief Record the referenced event on `__stream`.
  //!
  //! @param __stream The stream to record the event on. Asserted to be non-null.
  //! @pre The stored event handle is non-null (asserted).
  //! NOTE(review): error reporting is delegated to `detail::driver::eventRecord`;
  //! presumably it throws on failure — confirm against that helper.
  void record(stream_ref __stream) const
  {
    _CCCL_ASSERT(__event_ != nullptr, "cuda::experimental::event_ref::record no event set");
    _CCCL_ASSERT(__stream.get() != nullptr, "cuda::experimental::event_ref::record invalid stream passed");
    // Need to use driver API, cudaEventRecord will push dev 0 if stack is empty
    detail::driver::eventRecord(__event_, __stream.get());
  }

  //! @brief Wait on the host for the referenced event via `cudaEventSynchronize`.
  //!
  //! @pre The stored event handle is non-null (asserted).
  //! Failures are reported through `_CCCL_TRY_CUDA_API` with the message
  //! "Failed to wait for CUDA event".
  void wait() const
  {
    _CCCL_ASSERT(__event_ != nullptr, "cuda::experimental::event_ref::wait no event set");
    _CCCL_TRY_CUDA_API(::cudaEventSynchronize, "Failed to wait for CUDA event", __event_);
  }

  //! @brief Query the referenced event via `cudaEventQuery`.
  //!
  //! @pre The stored event handle is non-null (asserted).
  //! @return `true` if the query returns `cudaSuccess`, `false` if it returns
  //!         `cudaErrorNotReady`.
  //! @throws Any other status is passed to `::cuda::__throw_cuda_error`.
  _CCCL_NODISCARD bool is_done() const
  {
    _CCCL_ASSERT(__event_ != nullptr, "cuda::experimental::event_ref::is_done no event set");
    const ::cudaError_t __status = ::cudaEventQuery(__event_);
    if (__status == ::cudaErrorNotReady)
    {
      // Not an actual error: the captured work simply has not finished yet.
      return false;
    }
    if (__status != ::cudaSuccess)
    {
      ::cuda::__throw_cuda_error(__status, "Failed to query CUDA event");
    }
    // Every path ends in a return, regardless of how the throw helper is annotated.
    return true;
  }

  //! @brief Return the stored `cudaEvent_t` handle (may be null).
  _CCCL_NODISCARD constexpr ::cudaEvent_t get() const noexcept
  {
    return __event_;
  }

  //! @brief Return `true` if the stored handle is non-null.
  _CCCL_NODISCARD explicit constexpr operator bool() const noexcept
  {
    return __event_ != nullptr;
  }

  //! @brief Two references are equal when they hold the same handle.
  _CCCL_NODISCARD_FRIEND constexpr bool operator==(event_ref __lhs, event_ref __rhs) noexcept
  {
    return __lhs.__event_ == __rhs.__event_;
  }

  //! @brief Two references are unequal when they hold different handles.
  _CCCL_NODISCARD_FRIEND constexpr bool operator!=(event_ref __lhs, event_ref __rhs) noexcept
  {
    return __lhs.__event_ != __rhs.__event_;
  }
};
} // namespace cuda::experimental

#endif // _CUDAX_EVENT_REF_DETAIL_H