include/cuda/experimental/__launch/host_launch.cuh

File members: include/cuda/experimental/__launch/host_launch.cuh

//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__LAUNCH_HOST_LAUNCH
#define _CUDAX__LAUNCH_HOST_LAUNCH

#include <cuda/__cccl_config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__functional/reference_wrapper.h>
#include <cuda/std/__type_traits/decay.h>
#include <cuda/std/__utility/forward.h>
#include <cuda/std/tuple>
#include <cuda/stream_ref>

namespace cuda::experimental
{

//! @brief Stream callback trampoline that runs and then destroys a heap-allocated callable.
//!
//! `__callable_ptr` is cast back to `_CallablePtr`. The pointee is invoked only when `__status` is `cudaSuccess`;
//! it is deleted in either case, so this function takes over ownership of the pointed-to object.
//!
//! @tparam _CallablePtr Pointer type that `__callable_ptr` is cast to before use
//! @param __status Status handed to the callback; anything other than `cudaSuccess` skips the invocation
//! @param __callable_ptr Type-erased pointer to the callable
template <typename _CallablePtr>
void __stream_callback_caller(cudaStream_t, cudaError_t __status, void* __callable_ptr)
{
  const auto __owned_callable = static_cast<_CallablePtr>(__callable_ptr);

  if (__status != cudaSuccess)
  {
    // Error status: do not call the callable, but still release it
    delete __owned_callable;
    return;
  }

  (*__owned_callable)();
  delete __owned_callable;
}

//! @brief Adds a host-side invocation of `__callable` with `__args...` to `__stream`.
//!
//! The callable and the arguments are moved into a heap-allocated closure. The closure's address is passed to
//! `cudaStreamAddCallback` together with `__stream_callback_caller`, which invokes the closure (when the status it
//! receives is `cudaSuccess`) and deletes it.
//!
//! If adding the callback fails, the callback never takes over the closure, so it is released here before the
//! error leaves this function.
//!
//! @param __stream Stream the callback is added to
//! @param __callable Callable to invoke; taken by value and moved into the closure
//! @param __args Arguments for the callable; taken by value and moved into a tuple stored in the closure
template <typename _Callable, typename... _Args>
void host_launch(stream_ref __stream, _Callable __callable, _Args... __args)
{
  static_assert(_CUDA_VSTD::is_invocable_v<_Callable, _Args...>,
                "Callable can't be called with the supplied arguments");
  auto __lambda_ptr = new auto([__callable   = _CUDA_VSTD::move(__callable),
                                __args_tuple = _CUDA_VSTD::make_tuple(_CUDA_VSTD::move(__args)...)]() mutable {
    _CUDA_VSTD::apply(__callable, __args_tuple);
  });

  // Owns the closure until the callback has been enqueued. Without it, any abnormal exit from the API call below
  // (presumably an exception thrown by _CCCL_TRY_CUDA_API on failure) would leak the allocation, because the
  // callback that normally deletes it would never run.
  struct __closure_guard
  {
    decltype(__lambda_ptr) __ptr;

    ~__closure_guard()
    {
      delete __ptr; // no-op once ownership was handed over and __ptr was reset to nullptr
    }
  } __guard{__lambda_ptr};

  // We use the callback here to have it execute even on stream error, because it needs to free the above allocation
  _CCCL_TRY_CUDA_API(
    cudaStreamAddCallback,
    "Failed to launch host function",
    __stream.get(),
    __stream_callback_caller<decltype(__lambda_ptr)>,
    static_cast<void*>(__lambda_ptr),
    0);

  // The callback is enqueued: __stream_callback_caller is now responsible for deleting the closure
  __guard.__ptr = nullptr;
}

//! @brief Host function trampoline that invokes a callable through a type-erased pointer.
//!
//! `__callable_ptr` is cast back to `_CallablePtr` and the pointee is called with no arguments. Nothing is freed
//! here; the pointed-to object is left untouched apart from the call.
//!
//! @tparam _CallablePtr Pointer type that `__callable_ptr` is cast to before use
//! @param __callable_ptr Type-erased pointer to the callable
template <typename _CallablePtr>
void __host_func_launcher(void* __callable_ptr)
{
  auto&& __target = *static_cast<_CallablePtr>(__callable_ptr);
  __target();
}

//! @brief Adds a host-side invocation of a callable held in a `reference_wrapper` to `__stream`.
//!
//! No copy of the callable is made: only the address of the referenced object is passed to `cudaLaunchHostFunc`,
//! and `__host_func_launcher` calls the object through that address with no arguments. The referenced object
//! therefore has to remain valid until the enqueued call has run.
//!
//! @tparam _Args Not used by this overload
//! @param __stream Stream the host function is added to
//! @param __callable Reference wrapper around the callable to invoke
template <typename _Callable, typename... _Args>
void host_launch(stream_ref __stream, ::cuda::std::reference_wrapper<_Callable> __callable)
{
  static_assert(_CUDA_VSTD::is_invocable_v<_Callable>, "Callable in reference_wrapper can't take any arguments");

  using __callable_ptr_t = _Callable*;

  // Address of the wrapped object; this is the user data the launcher receives back
  __callable_ptr_t __target = _CUDA_VSTD::addressof(__callable.get());

  _CCCL_TRY_CUDA_API(
    cudaLaunchHostFunc,
    "Failed to launch host function",
    __stream.get(),
    __host_func_launcher<__callable_ptr_t>,
    __target);
}

} // namespace cuda::experimental

#endif // !_CUDAX__LAUNCH_HOST_LAUNCH