// include/cuda/experimental/__stf/graph/graph_data_interface.cuh
// File members: include/cuda/experimental/__stf/graph/graph_data_interface.cuh
//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <cuda/__cccl_config>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <cuda/experimental/__stf/graph/internal/event_types.cuh>
#include <cuda/experimental/__stf/internal/backend_ctx.cuh>
#include <cuda/experimental/__stf/internal/data_interface.cuh>
namespace cuda::experimental::stf
{
/**
 * @brief Data interface whose copies between data instances are expressed as
 * CUDA graph nodes.
 *
 * Derived classes implement `graph_data_copy`, which produces the graph node
 * for one copy. The `data_copy` override defined here calls it with the graph
 * obtained from the context and the nodes obtained from the prerequisite list.
 *
 * @tparam T Type forwarded to `data_impl_base<T>`
 */
template <typename T>
class graph_data_interface : public data_impl_base<T>
{
public:
  using base    = data_impl_base<T>;
  using shape_t = typename base::shape_t;

  /// Construct from a value of type `T`, which is moved into the base class.
  graph_data_interface(T p)
      : base(mv(p))
  {}

  /// Construct from a shape only, which is moved into the base class.
  graph_data_interface(shape_of<T> s)
      : base(mv(s))
  {}

  /**
   * @brief Produce the graph node that copies one data instance to another.
   *
   * Pure virtual: the implementation lives in derived classes.
   *
   * @param kind Copy direction (`data_copy` always passes `cudaMemcpyDefault`)
   * @param src_instance_id Identifier of the source data instance
   * @param dst_instance_id Identifier of the destination data instance
   * @param graph Graph handle passed in by `data_copy` (taken from the context)
   * @param input_nodes Pointer to the first of `input_cnt` graph nodes
   *        (presumably the dependencies of the copy node; confirm against implementations)
   * @param input_cnt Number of entries in `input_nodes`
   * @return The graph node representing the copy
   */
  virtual cudaGraphNode_t graph_data_copy(
    cudaMemcpyKind kind,
    instance_id_t src_instance_id,
    instance_id_t dst_instance_id,
    cudaGraph_t graph,
    const cudaGraphNode_t* input_nodes,
    size_t input_cnt) = 0;

  /**
   * @brief Express a copy between two data instances as a node of the context's graph.
   *
   * Nothing is returned: `prereqs` is handed by non-const reference to
   * `reserved::join_with_graph_nodes` before the copy and to
   * `reserved::fork_from_graph_node` after it.
   * NOTE(review): presumably the latter updates `prereqs` so that it depends on
   * the copy node; confirm against the helper's definition.
   *
   * @param ctx_ Context providing the graph and its epoch
   * @param dst_memory_node Destination place (only used by the sanity check)
   * @param dst_instance_id Identifier of the destination data instance
   * @param src_memory_node Source place (only used by the sanity check)
   * @param src_instance_id Identifier of the source data instance
   * @param prereqs Event list consumed before the copy and passed on after it
   */
  void data_copy(backend_ctx_untyped& ctx_,
                 [[maybe_unused]] const data_place& dst_memory_node,
                 instance_id_t dst_instance_id,
                 [[maybe_unused]] const data_place& src_memory_node,
                 instance_id_t src_instance_id,
                 event_list& prereqs) override
  {
    // The places are only read by this check, which disappears when assertions
    // are disabled (hence the [[maybe_unused]] markers on the parameters).
    assert(src_memory_node != dst_memory_node);

    cudaGraph_t graph  = ctx_.graph();
    size_t graph_epoch = ctx_.epoch();
    assert(graph && graph_epoch != size_t(-1));

    // Graph nodes obtained from the prerequisite events for this epoch
    const ::std::vector<cudaGraphNode_t> dep_nodes = reserved::join_with_graph_nodes(ctx_, prereqs, graph_epoch);

    // cudaMemcpyDefault: let CUDA infer the copy direction from the pointers
    cudaGraphNode_t copy_node = graph_data_copy(
      cudaMemcpyDefault, src_instance_id, dst_instance_id, graph, dep_nodes.data(), dep_nodes.size());

    reserved::fork_from_graph_node(ctx_, copy_node, graph, graph_epoch, prereqs, "copy");
  }
};
} // end namespace cuda::experimental::stf