include/cuda/experimental/__graph/path_builder.cuh
File members: include/cuda/experimental/__graph/path_builder.cuh
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#ifndef _CUDAX__GRAPH_PATH_BUILDER
#define _CUDAX__GRAPH_PATH_BUILDER
#include <cuda/std/detail/__config>
#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/std/__exception/cuda_error.h>
#include <cuda/experimental/__graph/concepts.cuh>
#include <cuda/experimental/__graph/graph_builder.cuh>
#include <cuda/experimental/__graph/graph_node_ref.cuh>
#include <cuda/experimental/__stream/stream_ref.cuh>
#include <algorithm>
#include <cassert>
#include <vector>
#include <cuda_runtime.h>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
namespace cuda::experimental
{
//! @brief Tracks a "path" through a CUDA graph: the set of graph nodes that the next node
//! appended along the path should depend on.
//!
//! The dependency set is kept free of duplicate entries, because the CUDA graph node-creation
//! APIs reject dependency arrays that contain duplicates.
struct path_builder
{
  //! @brief Creates a path with no dependencies in the graph referenced by @p __builder.
  _CCCL_HOST_API explicit path_builder(graph_builder_ref __builder)
      : __dev_{__builder.get_device()}
      , __graph_{__builder.get()}
  {}

  //! @brief Creates a path with no dependencies in the native graph @p __graph, associated with
  //! device @p __dev.
  _CCCL_HOST_API path_builder(device_ref __dev, cudaGraph_t __graph)
      : __dev_{__dev}
      , __graph_{__graph}
  {}

#if _CCCL_CTK_AT_LEAST(12, 3)
  //! @brief Captures the work that @p __capture_fn enqueues on @p __stream directly into this
  //! path's graph.
  //!
  //! Begins a capture with `cudaStreamBeginCaptureToGraph`, using the current path dependencies as
  //! the capture's initial dependencies. It then invokes `__capture_fn(__stream.get())` and ends the
  //! capture. On success, the path's dependencies are replaced by the capture's final dependency set
  //! (which may contain zero, one, or several nodes).
  //!
  //! The capture is terminated on every exit path, including when @p __capture_fn throws or when the
  //! capture was invalidated, so the stream is never left in capture mode.
  //!
  //! @throws cuda_error if a CUDA call fails, or if the capture is no longer active after
  //!         @p __capture_fn returns.
  template <typename _Fn>
  _CCCL_HOST_API void legacy_stream_capture(stream_ref __stream, _Fn&& __capture_fn)
  {
    const cudaStream_t __native_stream = __stream.get();
    _CCCL_TRY_CUDA_API(
      ::cudaStreamBeginCaptureToGraph,
      "Failed to begin stream capture",
      __native_stream,
      __graph_,
      __nodes_.data(),
      nullptr,
      __nodes_.size(),
      cudaStreamCaptureModeGlobal);

    // Ends the capture if we leave this scope before the explicit cudaStreamEndCapture below, e.g.
    // because __capture_fn threw, the capture-info query failed, or the capture was invalidated.
    // The graph is owned by the caller, so the handle returned here is intentionally discarded.
    struct __capture_guard
    {
      cudaStream_t __stream_;
      bool __armed_;
      ~__capture_guard()
      {
        if (__armed_)
        {
          cudaGraph_t __ignored = nullptr;
          (void) ::cudaStreamEndCapture(__stream_, &__ignored);
        }
      }
    } __guard{__native_stream, true};

    static_cast<_Fn&&>(__capture_fn)(__native_stream);

    cudaStreamCaptureStatus __capture_status = cudaStreamCaptureStatusNone;
    const cudaGraphNode_t* __final_deps      = nullptr;
    size_t __num_final_deps                  = 0;
    _CCCL_TRY_CUDA_API(
      ::cudaStreamGetCaptureInfo,
      "Failed to get stream capture info",
      __native_stream,
      &__capture_status,
      nullptr,
      nullptr,
      &__final_deps,
      &__num_final_deps);
    if (__capture_status != cudaStreamCaptureStatusActive)
    {
      // __guard terminates the (invalidated) capture while unwinding.
      __throw_cuda_error(cudaErrorInvalidValue, "Stream capture no longer active", "cudaStreamGetCaptureInfo");
    }

    // The array returned by cudaStreamGetCaptureInfo is only valid until the next API call on the
    // stream or until the capture ends, so it must be copied before cudaStreamEndCapture.
    ::std::vector<cudaGraphNode_t> __new_nodes(__final_deps, __final_deps + __num_final_deps);

    __guard.__armed_        = false; // the capture is ended explicitly below
    cudaGraph_t __graph_out = nullptr;
    _CCCL_TRY_CUDA_API(::cudaStreamEndCapture, "Failed to end stream capture", __native_stream, &__graph_out);
    assert(__graph_out == __graph_);

    __nodes_.swap(__new_nodes);
  }
#endif // _CCCL_CTK_AT_LEAST(12, 3)

  //! @brief Replaces the whole dependency set with the single node @p __node.
  _CCCL_HOST_API void __clear_and_set_dependency_node(cudaGraphNode_t __node)
  {
    __nodes_.clear(); // Clear existing nodes
    __nodes_.push_back(__node);
  }

  //! @brief Returns a view of the current dependency set.
  //! The view is invalidated by any later modification of this path.
  [[nodiscard]] _CCCL_TRIVIAL_HOST_API auto get_dependencies() const noexcept -> _CUDA_VSTD::span<const cudaGraphNode_t>
  {
    return _CUDA_VSTD::span(__nodes_.data(), __nodes_.size());
  }

  //! @brief Joins @p __other into this path by adding its dependencies to this path's set.
  //! Nodes already present are not added again, and waiting on itself is a no-op.
  _CCCL_HOST_API void wait(const path_builder& __other)
  {
    __add_dependencies(__other);
  }

  template <typename... Nodes>
  static constexpr bool __all_dependencies = (graph_dependency<Nodes> && ...);

  //! @brief Adds each argument to the dependency set.
  //! A `graph_node_ref` contributes its native node. Any other argument (a `path_builder`)
  //! contributes its whole dependency set. Duplicates are skipped.
  _CCCL_TEMPLATE(typename... Nodes)
  _CCCL_REQUIRES(__all_dependencies<Nodes...>)
  _CCCL_HOST_API void depends_on(Nodes&&... __nodes)
  {
    (
      [this](auto&& __arg) {
        if constexpr (_CUDA_VSTD::is_same_v<_CUDA_VSTD::decay_t<decltype(__arg)>, graph_node_ref>)
        {
          this->__add_dependency(__arg.get());
        }
        else
        {
          this->__add_dependencies(__arg);
        }
      }(static_cast<Nodes&&>(__nodes)),
      ...);
  }

  //! @brief Returns a non-owning reference to the graph this path builds into.
  [[nodiscard]] _CCCL_TRIVIAL_HOST_API constexpr auto get_graph() const noexcept -> graph_builder_ref
  {
    return graph_builder_ref(__graph_, __dev_);
  }

  //! @brief Returns the native `cudaGraph_t` handle this path builds into.
  [[nodiscard]] _CCCL_TRIVIAL_HOST_API constexpr auto get_native_graph_handle() const noexcept -> cudaGraph_t
  {
    return __graph_;
  }

  //! @brief Returns the device associated with this path.
  [[nodiscard]] _CCCL_TRIVIAL_HOST_API constexpr auto get_device() const noexcept -> device_ref
  {
    return __dev_;
  }

private:
  //! Appends @p __node unless it is already in the dependency set.
  //! CUDA node-creation APIs reject dependency arrays that contain duplicate entries.
  _CCCL_HOST_API void __add_dependency(cudaGraphNode_t __node)
  {
    if (::std::find(__nodes_.begin(), __nodes_.end(), __node) == __nodes_.end())
    {
      __nodes_.push_back(__node);
    }
  }

  //! Merges the dependency set of @p __other into this one, skipping duplicates.
  _CCCL_HOST_API void __add_dependencies(const path_builder& __other)
  {
    // Self-join adds nothing. Skipping it also avoids appending to a vector while iterating over it.
    if (&__other == this)
    {
      return;
    }
    for (const cudaGraphNode_t __node : __other.__nodes_)
    {
      __add_dependency(__node);
    }
  }

  device_ref __dev_;
  cudaGraph_t __graph_;
  // Current dependency set of the path (no duplicate entries).
  // TODO should this be a custom class that does inline storage for small counts?
  ::std::vector<cudaGraphNode_t> __nodes_;
};
//! @brief Starts a new path in the graph referenced by @p __gb.
//!
//! Every node passed in @p __nodes (if any) is recorded as an initial dependency of the returned
//! path via `path_builder::depends_on`.
template <typename... Nodes>
[[nodiscard]] _CCCL_HOST_API path_builder start_path(graph_builder_ref __gb, Nodes... __nodes)
{
  path_builder __path{__gb};
  // Only forward to depends_on when there is at least one initial dependency.
  if constexpr (sizeof...(Nodes) != 0)
  {
    __path.depends_on(__nodes...);
  }
  return __path;
}
//! @brief Starts a new path on device @p __dev, in the graph whose native handle is reported by
//! `__first_node.get_native_graph_handle()`.
//!
//! @p __first_node and any additional @p __nodes are recorded as the path's initial dependencies
//! via `path_builder::depends_on`.
template <typename _FirstNode, typename... _Nodes>
[[nodiscard]] _CCCL_HOST_API path_builder start_path(device_ref __dev, _FirstNode __first_node, _Nodes... __nodes)
{
  const auto __native_graph = __first_node.get_native_graph_handle();
  path_builder __path{__dev, __native_graph};
  __path.depends_on(__first_node, __nodes...);
  return __path;
}
} // namespace cuda::experimental
#endif // _CUDAX__GRAPH_PATH_BUILDER