cub/device/device_transform.cuh

File members: cub/device/device_transform.cuh

// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/detail/choose_offset.cuh>
#include <cub/detail/nvtx.cuh>
#include <cub/device/dispatch/dispatch_transform.cuh>
#include <cub/util_namespace.cuh>

#include <cuda/std/tuple>

CUB_NAMESPACE_BEGIN

struct DeviceTransform
{
  //! Entry point taking a tuple of input iterators.
  //!
  //! Opens an NVTX range named "cub::DeviceTransform::Transform", selects a signed offset type for `NumItemsT` through
  //! `detail::choose_signed_offset`, and returns that helper's error code when it reports that `num_items` does not
  //! fit the selected offset type. Otherwise every argument is handed to
  //! `detail::transform::dispatch_t<false, ...>::dispatch` and the dispatcher's result is returned unchanged.
  //!
  //! NOTE(review): presumably `transform_op` is applied element-wise to the values read from `inputs` and the results
  //! are written through `output`; that behavior lives in the dispatch layer, which is not visible from this file.
  //!
  //! @param inputs       Tuple of input iterators, moved into the dispatcher.
  //! @param output       Output iterator, moved into the dispatcher.
  //! @param num_items    Item count; validated against the selected offset type before dispatching.
  //! @param transform_op Callable, moved into the dispatcher.
  //! @param stream       Stream forwarded to the dispatcher; defaults to `nullptr`.
  //! @return The offset-range error if the check fails, otherwise whatever `dispatch` returns.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    CUB_DETAIL_NVTX_RANGE_SCOPE("cub::DeviceTransform::Transform");

    using offset_chooser = detail::choose_signed_offset<NumItemsT>;
    using offset_t       = typename offset_chooser::type;
    using input_tuple_t  = ::cuda::std::tuple<RandomAccessIteratorsIn...>;
    using dispatcher_t =
      detail::transform::dispatch_t<false, offset_t, input_tuple_t, RandomAccessIteratorOut, TransformOp>;

    // Refuse to dispatch when the item count is outside the range of the selected signed offset type.
    const cudaError_t range_status = offset_chooser::is_exceeding_offset_type(num_items);
    if (range_status != cudaSuccess)
    {
      return range_status;
    }

    return dispatcher_t::dispatch(
      ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Variant of the tuple overload that additionally accepts temporary-storage parameters, so that it can be called
  // like other CUB device APIs. A null d_temp_storage is treated as a size query: temp_storage_bytes is set to 1 and
  // cudaSuccess is returned without forwarding anything. A non-null d_temp_storage is never dereferenced here; the
  // call is simply forwarded to the overload without temporary storage.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return Transform(
        ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
    }

    // Size query: report a one-byte requirement and do no further work.
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED

  //! Convenience overload for a single input iterator.
  //!
  //! Wraps `input` into a one-element tuple with `::cuda::std::make_tuple` and forwards to the tuple-based
  //! `Transform` overload, returning its result.
  //!
  //! @param input        The only input iterator; moved into the tuple.
  //! @param output       Output iterator, moved along.
  //! @param num_items    Item count, forwarded as is.
  //! @param transform_op Callable, moved along.
  //! @param stream       Stream forwarded to the callee; defaults to `nullptr`.
  //! @return The result of the forwarded `Transform` call.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    return Transform(::cuda::std::make_tuple(::cuda::std::move(input)),
                     ::cuda::std::move(output),
                     num_items,
                     ::cuda::std::move(transform_op),
                     stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Single-input variant with temporary-storage parameters, so that it can be called like other CUB device APIs.
  // A null d_temp_storage is a size query (temp_storage_bytes becomes 1, cudaSuccess is returned); otherwise the
  // input is wrapped into a one-element tuple and the call is forwarded to the overload without temporary storage.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return Transform(::cuda::std::make_tuple(::cuda::std::move(input)),
                       ::cuda::std::move(output),
                       num_items,
                       ::cuda::std::move(transform_op),
                       stream);
    }

    // Size query: report a one-byte requirement and do no further work.
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED

  //! Tuple-based entry point that instantiates the dispatcher with its first template argument set to `true`.
  //!
  //! Mirrors `Transform` step by step: an NVTX range named "cub::DeviceTransform::TransformStableArgumentAddresses"
  //! is opened, a signed offset type is selected for `NumItemsT`, an out-of-range `num_items` yields the error
  //! reported by `is_exceeding_offset_type`, and otherwise the arguments go to
  //! `detail::transform::dispatch_t<true, ...>::dispatch`.
  //!
  //! NOTE(review): judging by the function name, the `true` flag asks the dispatch layer to keep the addresses of
  //! the arguments seen by `transform_op` stable; confirm against `dispatch_transform.cuh`.
  //!
  //! @param inputs       Tuple of input iterators, moved into the dispatcher.
  //! @param output       Output iterator, moved into the dispatcher.
  //! @param num_items    Item count; validated against the selected offset type before dispatching.
  //! @param transform_op Callable, moved into the dispatcher.
  //! @param stream       Stream forwarded to the dispatcher; defaults to `nullptr`.
  //! @return The offset-range error if the check fails, otherwise whatever `dispatch` returns.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    CUB_DETAIL_NVTX_RANGE_SCOPE("cub::DeviceTransform::TransformStableArgumentAddresses");

    using offset_chooser = detail::choose_signed_offset<NumItemsT>;
    using offset_t       = typename offset_chooser::type;
    using input_tuple_t  = ::cuda::std::tuple<RandomAccessIteratorsIn...>;
    using dispatcher_t =
      detail::transform::dispatch_t<true, offset_t, input_tuple_t, RandomAccessIteratorOut, TransformOp>;

    // Refuse to dispatch when the item count is outside the range of the selected signed offset type.
    const cudaError_t range_status = offset_chooser::is_exceeding_offset_type(num_items);
    if (range_status != cudaSuccess)
    {
      return range_status;
    }

    return dispatcher_t::dispatch(
      ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Tuple-based TransformStableArgumentAddresses with temporary-storage parameters. A null d_temp_storage is a size
  // query (temp_storage_bytes becomes 1, cudaSuccess is returned); otherwise the call is forwarded to the overload
  // without temporary storage.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return TransformStableArgumentAddresses(
        ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
    }

    // Size query: report a one-byte requirement and do no further work.
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED

  //! Convenience overload of `TransformStableArgumentAddresses` for a single input iterator.
  //!
  //! Wraps `input` into a one-element tuple with `::cuda::std::make_tuple` and forwards to the tuple-based
  //! `TransformStableArgumentAddresses` overload, returning its result.
  //!
  //! @param input        The only input iterator; moved into the tuple.
  //! @param output       Output iterator, moved along.
  //! @param num_items    Item count, forwarded as is.
  //! @param transform_op Callable, moved along.
  //! @param stream       Stream forwarded to the callee; defaults to `nullptr`.
  //! @return The result of the forwarded `TransformStableArgumentAddresses` call.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    return TransformStableArgumentAddresses(
      ::cuda::std::make_tuple(::cuda::std::move(input)),
      ::cuda::std::move(output),
      num_items,
      ::cuda::std::move(transform_op),
      stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Single-input TransformStableArgumentAddresses with temporary-storage parameters. A null d_temp_storage is a size
  // query (temp_storage_bytes becomes 1, cudaSuccess is returned); otherwise the input is wrapped into a one-element
  // tuple and the call is forwarded to the overload without temporary storage.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return TransformStableArgumentAddresses(
        ::cuda::std::make_tuple(::cuda::std::move(input)),
        ::cuda::std::move(output),
        num_items,
        ::cuda::std::move(transform_op),
        stream);
    }

    // Size query: report a one-byte requirement and do no further work.
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED
};

CUB_NAMESPACE_END