include/cuda/experimental/__graph/path_builder.cuh

File members: include/cuda/experimental/__graph/path_builder.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__GRAPH_PATH_BUILDER
#define _CUDAX__GRAPH_PATH_BUILDER

#include <cuda/std/detail/__config>

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/std/__exception/cuda_error.h>

#include <cuda/experimental/__graph/concepts.cuh>
#include <cuda/experimental/__graph/graph_builder.cuh>
#include <cuda/experimental/__graph/graph_node_ref.cuh>
#include <cuda/experimental/__stream/stream_ref.cuh>

#include <vector>

#include <cuda_runtime.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

namespace cuda::experimental
{

struct path_builder
{
  _CCCL_HOST_API explicit path_builder(graph_builder_ref __builder)
      : __dev_{__builder.get_device()}
      , __graph_{__builder.get()}
  {}

  path_builder(device_ref __dev, cudaGraph_t __graph)
      : __dev_{__dev}
      , __graph_{__graph}
  {}

#if _CCCL_CTK_AT_LEAST(12, 3)
  template <typename _Fn>
  _CCCL_HOST_API void legacy_stream_capture(stream_ref __stream, _Fn&& __capture_fn)
  {
    _CCCL_TRY_CUDA_API(
      ::cudaStreamBeginCaptureToGraph,
      "Failed to begin stream capture",
      __stream.get(),
      __graph_,
      __nodes_.data(),
      nullptr,
      __nodes_.size(),
      cudaStreamCaptureModeGlobal);

    __capture_fn(__stream.get());

    cudaGraph_t __graph_out = nullptr;

    cudaStreamCaptureStatus __capture_status;
    const cudaGraphNode_t* __last_captured_node = nullptr;
    size_t __num_nodes                          = 0;
    _CCCL_TRY_CUDA_API(
      ::cudaStreamGetCaptureInfo,
      "Failed to get stream capture info",
      __stream.get(),
      &__capture_status,
      nullptr,
      nullptr,
      &__last_captured_node,
      &__num_nodes);
    if (__capture_status != cudaStreamCaptureStatusActive)
    {
      __throw_cuda_error(cudaErrorInvalidValue, "Stream capture no longer active", "cudaStreamGetCaptureInfo");
    }
    _CCCL_TRY_CUDA_API(::cudaStreamEndCapture, "Failed to end stream capture", __stream.get(), &__graph_out);
    assert(__graph_out == __graph_);
    assert(__num_nodes == 1);
    __nodes_.clear();
    __nodes_.push_back(__last_captured_node[0]);
  }
#endif // _CCCL_CTK_AT_LEAST(12, 3)

  _CCCL_HOST_API void __clear_and_set_dependency_node(cudaGraphNode_t __node)
  {
    __nodes_.clear(); // Clear existing nodes
    __nodes_.push_back(__node);
  }

  [[nodiscard]] _CCCL_TRIVIAL_HOST_API auto get_dependencies() const noexcept -> _CUDA_VSTD::span<const cudaGraphNode_t>
  {
    return _CUDA_VSTD::span(__nodes_.data(), __nodes_.size());
  }

  _CCCL_HOST_API void wait(const path_builder& __other)
  {
    __nodes_.insert(__nodes_.end(), __other.__nodes_.begin(), __other.__nodes_.end());
  }

  template <typename... Nodes>
  static constexpr bool __all_dependencies = (graph_dependency<Nodes> && ...);

  _CCCL_TEMPLATE(typename... Nodes)
  _CCCL_REQUIRES(__all_dependencies<Nodes...>)
  _CCCL_HOST_API void depends_on(Nodes&&... __nodes)
  {
    (
      [this](auto&& __arg) {
        if constexpr (_CUDA_VSTD::is_same_v<_CUDA_VSTD::decay_t<decltype(__arg)>, graph_node_ref>)
        {
          __nodes_.push_back(__arg.get());
        }
        else
        {
          __nodes_.insert(__nodes_.end(), __arg.__nodes_.begin(), __arg.__nodes_.end());
        }
      }(static_cast<Nodes&&>(__nodes)),
      ...);
  }

  [[nodiscard]] _CCCL_TRIVIAL_HOST_API constexpr auto get_graph() const noexcept -> graph_builder_ref
  {
    return graph_builder_ref(__graph_, __dev_);
  }

  [[nodiscard]] _CCCL_TRIVIAL_HOST_API constexpr auto get_native_graph_handle() const noexcept -> cudaGraph_t
  {
    return __graph_;
  }

  [[nodiscard]] _CCCL_TRIVIAL_HOST_API constexpr auto get_device() const noexcept -> device_ref
  {
    return __dev_;
  }

private:
  device_ref __dev_;
  cudaGraph_t __graph_;
  // TODO should this be a custom class that does inline storage for small counts?
  ::std::vector<cudaGraphNode_t> __nodes_;
};

//! @brief Creates a `path_builder` from `__gb`, optionally seeded with dependencies.
//!
//! @param __gb Graph builder reference the new path is constructed from.
//! @param __nodes Zero or more dependencies passed to `path_builder::depends_on`.
//! @return The newly constructed `path_builder`.
template <typename... _Deps>
[[nodiscard]] _CCCL_HOST_API path_builder start_path(graph_builder_ref __gb, _Deps... __nodes)
{
  path_builder __path(__gb);
  // With an empty pack the path is returned without any dependencies.
  if constexpr (sizeof...(_Deps) != 0)
  {
    __path.depends_on(__nodes...);
  }
  return __path;
}

//! @brief Creates a `path_builder` on `__dev` for the graph that `__first_node` reports,
//! depending on `__first_node` and every entry of `__nodes`.
//!
//! @param __dev Device stored in the new path.
//! @param __first_node First dependency; its `get_native_graph_handle()` selects the graph.
//! @param __nodes Further dependencies passed to `path_builder::depends_on` after `__first_node`.
//! @return The newly constructed `path_builder`.
template <typename _FirstNode, typename... _Nodes>
[[nodiscard]] _CCCL_HOST_API path_builder start_path(device_ref __dev, _FirstNode __first_node, _Nodes... __nodes)
{
  // The graph handle is taken from the first dependency rather than from a graph builder.
  auto&& __native_graph = __first_node.get_native_graph_handle();
  path_builder __path(__dev, __native_graph);
  __path.depends_on(__first_node, __nodes...);
  return __path;
}

} // namespace cuda::experimental

#endif // _CUDAX__GRAPH_PATH_BUILDER