// cub/device/device_transform.cuh
// File members: cub/device/device_transform.cuh
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#pragma once
#include <cub/config.cuh>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <cub/detail/choose_offset.cuh>
#include <cub/detail/nvtx.cuh>
#include <cub/device/dispatch/dispatch_transform.cuh>
#include <cub/util_namespace.cuh>
#include <cuda/std/tuple>
CUB_NAMESPACE_BEGIN
//! Device-wide transform entry points.
//!
//! Each public entry point accepts either a tuple of input iterators or a single input iterator, plus one output
//! iterator, an item count, a transformation operation and an optional CUDA stream. Every path ends in
//! `detail::transform::dispatch_t<...>::dispatch`. Before that call, `num_items` is validated against the signed offset
//! type chosen by `detail::choose_signed_offset<NumItemsT>`.
//!
//! The two families differ only in the first template argument they hand to `dispatch_t`:
//! - `Transform` passes `false`.
//! - `TransformStableArgumentAddresses` passes `true`.
//!
//! NOTE(review): judging by the name, `true` asks for `transform_op` to receive the input elements at their original
//! addresses. That behavior lives in dispatch_transform.cuh; confirm it there.
struct DeviceTransform
{
  //! Multi-input transform.
  //!
  //! Opens an NVTX range and checks that `num_items` fits the signed offset type selected for `NumItemsT`. It then
  //! forwards everything to `dispatch_t<false, ...>::dispatch`.
  //!
  //! @param[in] inputs        Tuple of input iterators (random access, per the template parameter names).
  //! @param[out] output       Output iterator.
  //! @param[in] num_items     Item count; must be representable by the selected signed offset type.
  //! @param[in] transform_op  Transformation operation; moved into the dispatch layer.
  //! @param[in] stream        CUDA stream passed through to the dispatch layer. Defaults to `nullptr` (default stream).
  //! @return The error from the offset-range check if it fails, otherwise whatever the dispatch layer returns.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    CUB_DETAIL_NVTX_RANGE_SCOPE("cub::DeviceTransform::Transform");

    using offset_selector_t = detail::choose_signed_offset<NumItemsT>;
    using dispatcher_t      = detail::transform::dispatch_t<
      false, // TransformStableArgumentAddresses passes `true` here
      typename offset_selector_t::type,
      ::cuda::std::tuple<RandomAccessIteratorsIn...>,
      RandomAccessIteratorOut,
      TransformOp>;

    // Refuse item counts that the chosen signed offset type cannot represent before dispatching anything.
    const cudaError_t status = offset_selector_t::is_exceeding_offset_type(num_items);
    if (status != cudaSuccess)
    {
      return status;
    }

    return dispatcher_t::dispatch(
      ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Variant taking the usual CUB temporary-storage parameters; it exists for compatibility with other CUB APIs.
  // No temporary storage is consumed:
  // - A null `d_temp_storage` is answered as a size query.
  // - A non-null pointer is otherwise ignored, and the call goes to the overload without temporary storage.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return Transform(
        ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
    }

    // Size query: report one byte rather than zero. Presumably this ensures that callers following the two-phase
    // pattern get a non-null allocation, so their second call takes the execution branch above.
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED

  //! Single-input transform.
  //!
  //! Wraps `input` in a one-element tuple and forwards to the multi-input `Transform`. That overload supplies the
  //! NVTX range and the offset-range check.
  //!
  //! @param[in] input         Input iterator (random access, per the template parameter name).
  //! @param[out] output       Output iterator.
  //! @param[in] num_items     Item count; must be representable by the selected signed offset type.
  //! @param[in] transform_op  Transformation operation.
  //! @param[in] stream        CUDA stream passed through. Defaults to `nullptr` (default stream).
  //! @return Whatever the multi-input overload returns.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    // The tuple overload is more specialized than this one, so partial ordering picks it (no self-recursion).
    return Transform(::cuda::std::make_tuple(::cuda::std::move(input)),
                     ::cuda::std::move(output),
                     num_items,
                     ::cuda::std::move(transform_op),
                     stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Single-input variant with the usual CUB temporary-storage parameters, for compatibility with other CUB APIs.
  // - A null `d_temp_storage` is answered as a size query.
  // - Otherwise the temporary-storage arguments are ignored, and `input` is forwarded wrapped in a one-element tuple.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t Transform(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return Transform(::cuda::std::make_tuple(::cuda::std::move(input)),
                       ::cuda::std::move(output),
                       num_items,
                       ::cuda::std::move(transform_op),
                       stream);
    }

    // Size query: one byte instead of zero (presumably to keep the caller's allocation non-null).
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED

  //! Multi-input transform using the `true` flavor of the dispatch layer.
  //!
  //! Identical to the multi-input `Transform` except for two things:
  //! - It uses its own NVTX range name.
  //! - It passes `true` as the first template argument of `dispatch_t`.
  //!
  //! NOTE(review): the name suggests this guarantees stable addresses for the arguments given to `transform_op`;
  //! confirm in dispatch_transform.cuh.
  //!
  //! @param[in] inputs        Tuple of input iterators (random access, per the template parameter names).
  //! @param[out] output       Output iterator.
  //! @param[in] num_items     Item count; must be representable by the selected signed offset type.
  //! @param[in] transform_op  Transformation operation; moved into the dispatch layer.
  //! @param[in] stream        CUDA stream passed through to the dispatch layer. Defaults to `nullptr` (default stream).
  //! @return The error from the offset-range check if it fails, otherwise whatever the dispatch layer returns.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    CUB_DETAIL_NVTX_RANGE_SCOPE("cub::DeviceTransform::TransformStableArgumentAddresses");

    using offset_selector_t = detail::choose_signed_offset<NumItemsT>;
    using dispatcher_t      = detail::transform::dispatch_t<
      true, // Transform passes `false` here
      typename offset_selector_t::type,
      ::cuda::std::tuple<RandomAccessIteratorsIn...>,
      RandomAccessIteratorOut,
      TransformOp>;

    // Refuse item counts that the chosen signed offset type cannot represent before dispatching anything.
    const cudaError_t status = offset_selector_t::is_exceeding_offset_type(num_items);
    if (status != cudaSuccess)
    {
      return status;
    }

    return dispatcher_t::dispatch(
      ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Variant taking the usual CUB temporary-storage parameters; it exists for compatibility with other CUB APIs.
  // No temporary storage is consumed:
  // - A null `d_temp_storage` is answered as a size query.
  // - A non-null pointer is otherwise ignored, and the call goes to the overload without temporary storage.
  template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return TransformStableArgumentAddresses(
        ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream);
    }

    // Size query: one byte instead of zero (presumably to keep the caller's allocation non-null).
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED

  //! Single-input form of `TransformStableArgumentAddresses`.
  //!
  //! Wraps `input` in a one-element tuple and forwards to the multi-input overload. That overload supplies the NVTX
  //! range and the offset-range check.
  //!
  //! @param[in] input         Input iterator (random access, per the template parameter name).
  //! @param[out] output       Output iterator.
  //! @param[in] num_items     Item count; must be representable by the selected signed offset type.
  //! @param[in] transform_op  Transformation operation.
  //! @param[in] stream        CUDA stream passed through. Defaults to `nullptr` (default stream).
  //! @return Whatever the multi-input overload returns.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    // The tuple overload is more specialized than this one, so partial ordering picks it (no self-recursion).
    return TransformStableArgumentAddresses(::cuda::std::make_tuple(::cuda::std::move(input)),
                                            ::cuda::std::move(output),
                                            num_items,
                                            ::cuda::std::move(transform_op),
                                            stream);
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  // Single-input variant with the usual CUB temporary-storage parameters, for compatibility with other CUB APIs.
  // - A null `d_temp_storage` is answered as a size query.
  // - Otherwise the temporary-storage arguments are ignored, and `input` is forwarded wrapped in a one-element tuple.
  template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    RandomAccessIteratorIn input,
    RandomAccessIteratorOut output,
    NumItemsT num_items,
    TransformOp transform_op,
    cudaStream_t stream = nullptr)
  {
    if (d_temp_storage != nullptr)
    {
      return TransformStableArgumentAddresses(::cuda::std::make_tuple(::cuda::std::move(input)),
                                              ::cuda::std::move(output),
                                              num_items,
                                              ::cuda::std::move(transform_op),
                                              stream);
    }

    // Size query: one byte instead of zero (presumably to keep the caller's allocation non-null).
    temp_storage_bytes = 1;
    return cudaSuccess;
  }
#endif // _CCCL_DOXYGEN_INVOKED
};
CUB_NAMESPACE_END