cub/thread/thread_reduce.cuh

File members: cub/thread/thread_reduce.cuh

/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/detail/array_utils.cuh> // to_array()
#include <cub/detail/type_traits.cuh> // are_same()
#include <cub/thread/thread_load.cuh> // UnrolledCopy
#include <cub/thread/thread_operators.cuh> // cub_operator_to_dpx_t
#include <cub/util_namespace.cuh>

#include <cuda/functional> // cuda::std::maximum
#include <cuda/std/array> // array
#include <cuda/std/bit> // bit_cast
#include <cuda/std/cassert> // assert
#include <cuda/std/cstdint> // uint16_t
#include <cuda/std/functional> // cuda::std::plus

#if defined(_CCCL_HAS_NVFP16)
#  include <cuda_fp16.h>
#endif // _CCCL_HAS_NVFP16

#if defined(_CCCL_HAS_NVBF16)
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
#  include <cuda_bf16.h>
_CCCL_DIAG_POP
#endif // _CCCL_HAS_NVBF16

CUB_NAMESPACE_BEGIN

// Forward declaration of the public ThreadReduce(input, reduction_op) overload that is defined later in this file.
// It is needed here because internal::ThreadReduceSimd() calls cub::ThreadReduce() on the packed SIMD array.
//  - ValueT defaults to the cv/ref-stripped type of `input[0]`
//  - AccumT defaults to ::cuda::std::__accumulator_t<ReductionOp, ValueT>
template <typename Input,
          typename ReductionOp,
#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
          typename ValueT = ::cuda::std::remove_cvref_t<decltype(::cuda::std::declval<Input>()[0])>,
#else
          typename ValueT = random_access_value_t<Input>,
#endif // !_CCCL_DOXYGEN_INVOKED
          typename AccumT = ::cuda::std::__accumulator_t<ReductionOp, ValueT>>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduce(const Input& input, ReductionOp reduction_op);
// forward declaration (definition: see "Reduction Interface/Dispatch (public)" below)

/***********************************************************************************************************************
 * Internal Reduction Implementations
 **********************************************************************************************************************/

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document

namespace detail
{

// Reinterprets the bytes of `input` as an object of type `Output` by copying them into a fresh `Output`.
// Both types must have the same size (checked at compile time).
// NOTE: bit_cast cannot always be used because __half, __nv_bfloat16, etc. are not trivially copyable
template <typename Output, typename Input>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE Output unsafe_bitcast(const Input& input)
{
  static_assert(sizeof(Input) == sizeof(Output), "wrong size");
  Output result;
  ::memcpy(&result, &input, sizeof(Input));
  return result;
}

} // namespace detail

namespace internal
{

/***********************************************************************************************************************
 * Enable SIMD/Tree reduction heuristics
 **********************************************************************************************************************/

// TODO: add Blackwell support

// Trait consulted by enable_generic_simd_reduction() to decide whether the (T, ReductionOp) pair is handled by the
// generic SIMD reduction.
// Primary template: `value` is true only for 16-bit signed/unsigned integers combined with minimum/maximum.
template <typename T, typename ReductionOp>
struct enable_generic_simd_reduction_traits
{
  static constexpr bool value =
    cub::detail::is_one_of<T, ::cuda::std::int16_t, ::cuda::std::uint16_t>()
    && cub::detail::
      is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::minimum<T>, ::cuda::maximum<>, ::cuda::maximum<T>>();
};

#  if defined(_CCCL_HAS_NVFP16)

// Specialization for __half: `value` is true for minimum, maximum, plus, and multiplies
// (both the transparent `<>` form and the `<__half>` form of each operator).
template <typename ReductionOp>
struct enable_generic_simd_reduction_traits<__half, ReductionOp>
{
  static constexpr bool value = cub::detail::is_one_of<
    ReductionOp,
    ::cuda::minimum<>,
    ::cuda::minimum<__half>,
    ::cuda::maximum<>,
    ::cuda::maximum<__half>,
    ::cuda::std::plus<>,
    ::cuda::std::plus<__half>,
    ::cuda::std::multiplies<>,
    ::cuda::std::multiplies<__half>>();
};
#  endif // defined(_CCCL_HAS_NVFP16)

#  if defined(_CCCL_HAS_NVBF16)

// Specialization for __nv_bfloat16: `value` is true for minimum, maximum, plus, and multiplies
// (both the transparent `<>` form and the `<__nv_bfloat16>` form of each operator).
template <typename ReductionOp>
struct enable_generic_simd_reduction_traits<__nv_bfloat16, ReductionOp>
{
  static constexpr bool value = cub::detail::is_one_of<
    ReductionOp,
    ::cuda::minimum<>,
    ::cuda::minimum<__nv_bfloat16>,
    ::cuda::maximum<>,
    ::cuda::maximum<__nv_bfloat16>,
    ::cuda::std::plus<>,
    ::cuda::std::plus<__nv_bfloat16>,
    ::cuda::std::multiplies<>,
    ::cuda::std::multiplies<__nv_bfloat16>>();
};

#  endif // defined(_CCCL_HAS_NVBF16)

// Returns true when the generic SIMD reduction can be used for `Input`: the element type / operator pair is accepted
// by enable_generic_simd_reduction_traits and the input holds at least 4 elements.
template <typename Input, typename ReductionOp>
_CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_generic_simd_reduction()
{
  using ValueT = ::cuda::std::remove_cvref_t<decltype(::cuda::std::declval<Input>()[0])>;
  return enable_generic_simd_reduction_traits<ValueT, ReductionOp>::value
      && static_cast<int>(cub::detail::static_size_v<Input>()) >= 4;
}

// SM90 heuristic: SIMD reduction is enabled for 16-bit integers with minimum/maximum and at least 10 elements.
template <typename T, typename ReductionOp, int Length>
_CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_sm90_simd_reduction()
{
  using cub::detail::is_one_of;
  // ::cuda::std::plus<> not handled: IADD3 always produces less instructions than VIADD2
  return Length >= 10 //
      && is_one_of<T, ::cuda::std::int16_t, ::cuda::std::uint16_t>()
      && is_one_of<ReductionOp, ::cuda::minimum<>, ::cuda::minimum<T>, ::cuda::maximum<>, ::cuda::maximum<T>>();
}

// SM80 heuristic: SIMD reduction is enabled when
//  - ReductionOp is minimum, maximum, plus, or multiplies (transparent or <T> form),
//  - the input holds at least 4 elements, and
//  - T is __half and/or __nv_bfloat16, depending on which of the two types is available in this build.
// When neither _CCCL_HAS_NVFP16 nor _CCCL_HAS_NVBF16 is defined no element type is eligible, so the result must be
// false. Without the explicit `#else` branch the type check would vanish and the predicate would be true for any T
// (e.g. float + plus), dispatching to the ThreadReduceSimd fallback that asserts and returns input[0].
// The body is kept as a single return statement so the function stays a valid C++11 constexpr function.
template <typename T, typename ReductionOp, int Length>
_CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_sm80_simd_reduction()
{
  using cub::detail::is_one_of;
  using ::cuda::std::is_same;
  return is_one_of<ReductionOp,
                   ::cuda::minimum<>,
                   ::cuda::minimum<T>,
                   ::cuda::maximum<>,
                   ::cuda::maximum<T>,
                   ::cuda::std::plus<>,
                   ::cuda::std::plus<T>,
                   ::cuda::std::multiplies<>,
                   ::cuda::std::multiplies<T>>()
      && Length >= 4
#  if defined(_CCCL_HAS_NVFP16) && defined(_CCCL_HAS_NVBF16)
      && (is_same<T, __half>::value || is_same<T, __nv_bfloat16>::value)
#  elif defined(_CCCL_HAS_NVFP16)
      && is_same<T, __half>::value
#  elif defined(_CCCL_HAS_NVBF16)
      && is_same<T, __nv_bfloat16>::value
#  else
      && false // no 16-bit floating-point type available: nothing can be vectorized
#  endif
    ;
}

// SM70 heuristic: SIMD reduction is enabled only for __half with plus/multiplies and at least 4 elements.
// Always false when __half is not available (_CCCL_HAS_NVFP16 undefined).
template <typename T, typename ReductionOp, int Length>
_CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_sm70_simd_reduction()
{
#  if defined(_CCCL_HAS_NVFP16)
  using cub::detail::is_one_of;
  return Length >= 4 //
      && ::cuda::std::is_same<T, __half>::value
      && is_one_of<ReductionOp,
                   ::cuda::std::plus<>,
                   ::cuda::std::plus<T>,
                   ::cuda::std::multiplies<>,
                   ::cuda::std::multiplies<T>>();
#  else
  return false;
#  endif
}

// Decides whether ThreadReduce() should take the SIMD path for the given input, operator, and accumulator type.
//  - Returns false whenever the element type T differs from AccumT.
//  - Otherwise combines the per-architecture heuristics: each target level also accepts everything the lower
//    levels accept (SM90: sm90 || sm80 || sm70, SM80: sm80 || sm70, SM70: sm70 only); any other device target
//    returns false.
// NOTE(review): this function dispatches with `_NV_TARGET_DISPATCH` while enable_ternary_reduction() uses
// `NV_DISPATCH_TARGET`; confirm that the underscore-prefixed macro is the intended one.
template <typename Input, typename ReductionOp, typename AccumT>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE _CCCL_CONSTEXPR_CXX14 bool enable_simd_reduction()
{
  using cub::detail::is_one_of;
  using ::cuda::std::is_same;
  using T = ::cuda::std::remove_cvref_t<decltype(::cuda::std::declval<Input>()[0])>;
  // The SIMD path returns the element type, so it is only usable when no accumulator conversion is requested
  _CCCL_IF_CONSTEXPR (!is_same<T, AccumT>::value)
  {
    return false;
  }
  else
  {
    constexpr auto length = cub::detail::static_size_v<Input>();
    // clang-format off
    _NV_TARGET_DISPATCH(
      NV_PROVIDES_SM_90,
        (return enable_sm90_simd_reduction<T, ReductionOp, length>() ||
                enable_sm80_simd_reduction<T, ReductionOp, length>() ||
                enable_sm70_simd_reduction<T, ReductionOp, length>();),
      NV_PROVIDES_SM_80,
        (return enable_sm80_simd_reduction<T, ReductionOp, length>() ||
                enable_sm70_simd_reduction<T, ReductionOp, length>();),
      NV_PROVIDES_SM_70,
        (return enable_sm70_simd_reduction<T, ReductionOp, length>();),
      NV_IS_DEVICE,
        (static_cast<void>(length); // maybe unused
         return false;)
    );
    // clang-format on
    // Fallback for the case where none of the dispatch clauses above is selected
    return false;
  }
  return false; // nvcc 11.x warning workaround
}

/***********************************************************************************************************************
 * enable_ternary_reduction
 **********************************************************************************************************************/

// Trait used by enable_ternary_reduction() on the SM90 dispatch branch.
// Primary template: `value` is true for 32-bit signed/unsigned integers combined with minimum, maximum, plus,
// bit_and, bit_or, or bit_xor (transparent or <T> form).
template <typename T, typename ReductionOp>
struct enable_ternary_reduction_sm90
{
  static constexpr bool value =
    cub::detail::is_one_of<T, ::cuda::std::int32_t, ::cuda::std::uint32_t>()
    && cub::detail::is_one_of<
      ReductionOp,
      ::cuda::minimum<>,
      ::cuda::minimum<T>,
      ::cuda::maximum<>,
      ::cuda::maximum<T>,
      ::cuda::std::plus<>,
      ::cuda::std::plus<T>,
      ::cuda::std::bit_and<>,
      ::cuda::std::bit_and<T>,
      ::cuda::std::bit_or<>,
      ::cuda::std::bit_or<T>,
      ::cuda::std::bit_xor<>,
      ::cuda::std::bit_xor<T>>();
};

#  if defined(_CCCL_HAS_NVFP16)

// Specialization for the packed type __half2: `value` is true for minimum/maximum (transparent or <__half2> form)
// and for the SIMD operators SimdMin<__half> / SimdMax<__half>.
template <typename ReductionOp>
struct enable_ternary_reduction_sm90<__half2, ReductionOp>
{
  static constexpr bool value =
    cub::detail::is_one_of<ReductionOp,
                           ::cuda::minimum<>,
                           ::cuda::minimum<__half2>,
                           ::cuda::maximum<>,
                           ::cuda::maximum<__half2>,
                           SimdMin<__half>,
                           SimdMax<__half>>();
};

#  endif // defined(_CCCL_HAS_NVFP16)

#  if defined(_CCCL_HAS_NVBF16)

// Specialization for the packed type __nv_bfloat162: `value` is true for minimum/maximum (transparent or
// <__nv_bfloat162> form) and for the SIMD operators SimdMin<__nv_bfloat16> / SimdMax<__nv_bfloat16>.
template <typename ReductionOp>
struct enable_ternary_reduction_sm90<__nv_bfloat162, ReductionOp>
{
  static constexpr bool value =
    cub::detail::is_one_of<ReductionOp,
                           ::cuda::minimum<>,
                           ::cuda::minimum<__nv_bfloat162>,
                           ::cuda::maximum<>,
                           ::cuda::maximum<__nv_bfloat162>,
                           SimdMin<__nv_bfloat16>,
                           SimdMax<__nv_bfloat16>>();
};

#  endif // defined(_CCCL_HAS_NVBF16)

// Decides whether ThreadReduce() should use the ternary-tree reduction.
//  - Inputs with fewer than 6 elements never use it.
//  - SM90 branch: defers to enable_ternary_reduction_sm90, keyed on the *element* type T.
//  - SM50 branch: requires a 32-bit integer *accumulator* type (AccumT) and plus/bit_and/bit_or/bit_xor.
//  - Any other target: false.
template <typename Input, typename ReductionOp, typename AccumT>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE _CCCL_CONSTEXPR_CXX14 bool enable_ternary_reduction()
{
  using cub::detail::is_one_of;
  using ::cuda::std::is_same;
  using T               = ::cuda::std::remove_cvref_t<decltype(::cuda::std::declval<Input>()[0])>;
  constexpr auto length = cub::detail::static_size_v<Input>();
  _CCCL_IF_CONSTEXPR (length < 6)
  {
    return false;
  }
  else
  {
    // clang-format off
    NV_DISPATCH_TARGET(
      NV_PROVIDES_SM_90,
        (return enable_ternary_reduction_sm90<T, ReductionOp>::value;),
      NV_PROVIDES_SM_50,
        (return is_one_of<AccumT, ::cuda::std::int32_t, ::cuda::std::uint32_t>()
             && is_one_of<ReductionOp, ::cuda::std::plus<>,    ::cuda::std::plus<T>,
                                       ::cuda::std::bit_and<>, ::cuda::std::bit_and<T>,
                                       ::cuda::std::bit_or<>,  ::cuda::std::bit_or<T>,
                                       ::cuda::std::bit_xor<>, ::cuda::std::bit_xor<T>>();),
      NV_ANY_TARGET,
        (return false;)
    );
    // clang-format on
  }
  return false; // nvcc 11.x warning workaround
}

// Returns true when the reduction should be carried out in a wider type: the element type is an integral type of at
// most 2 bytes and the operator is one of plus, multiplies, bit_and, bit_or, bit_xor, maximum, minimum.
// ThreadReduce() uses the result to select `int` as the working type.
// AccumT is part of the signature but does not take part in the decision.
template <typename Input, typename ReductionOp, typename AccumT>
_CCCL_NODISCARD _CCCL_DEVICE constexpr bool enable_promotion()
{
  using cub::detail::is_one_of;
  using ValueT = ::cuda::std::remove_cvref_t<decltype(::cuda::std::declval<Input>()[0])>;
  return ::cuda::std::is_integral<ValueT>::value //
      && sizeof(ValueT) <= 2
      && is_one_of<ReductionOp,
                   ::cuda::std::plus<>,
                   ::cuda::std::plus<ValueT>,
                   ::cuda::std::multiplies<>,
                   ::cuda::std::multiplies<ValueT>,
                   ::cuda::std::bit_and<>,
                   ::cuda::std::bit_and<ValueT>,
                   ::cuda::std::bit_or<>,
                   ::cuda::std::bit_or<ValueT>,
                   ::cuda::std::bit_xor<>,
                   ::cuda::std::bit_xor<ValueT>,
                   ::cuda::maximum<>,
                   ::cuda::maximum<ValueT>,
                   ::cuda::minimum<>,
                   ::cuda::minimum<ValueT>>();
}

/***********************************************************************************************************************
 * Internal Reduction Algorithms: Sequential, Binary, Ternary
 **********************************************************************************************************************/

// Left-to-right reduction: starts from input[0] (converted to AccumT) and folds in input[1], input[2], ... in order.
template <typename AccumT, typename Input, typename ReductionOp>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT
ThreadReduceSequential(const Input& input, ReductionOp reduction_op)
{
  constexpr auto length = cub::detail::static_size_v<Input>();
  AccumT accumulator    = input[0];
#  pragma unroll
  for (int index = 1; index < length; ++index)
  {
    accumulator = reduction_op(accumulator, input[index]);
  }
  return accumulator;
}

// Binary-tree reduction over a private AccumT copy of the input.
// At each level, elements `stride` apart are combined pairwise into the left element; the stride doubles per level
// until the full result has been accumulated in element 0.
template <typename AccumT, typename Input, typename ReductionOp>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT
ThreadReduceBinaryTree(const Input& input, ReductionOp reduction_op)
{
  constexpr auto length = cub::detail::static_size_v<Input>();
  auto values           = cub::detail::to_array<AccumT>(input);
#  pragma unroll
  for (int stride = 1; stride < length; stride *= 2)
  {
#  pragma unroll
    for (int left = 0; left + stride < length; left += stride * 2)
    {
      values[left] = reduction_op(values[left], values[left + stride]);
    }
  }
  return values[0];
}

// Ternary-tree reduction over a private AccumT copy of the input.
// At each level, up to three elements `stride` apart are combined into the left-most one; the stride triples per
// level until the full result has been accumulated in element 0.
template <typename AccumT, typename Input, typename ReductionOp>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT
ThreadReduceTernaryTree(const Input& input, ReductionOp reduction_op)
{
  constexpr auto length = cub::detail::static_size_v<Input>();
  auto values           = cub::detail::to_array<AccumT>(input);
#  pragma unroll
  for (int stride = 1; stride < length; stride *= 3)
  {
#  pragma unroll
    for (int left = 0; left + stride < length; left += stride * 3)
    {
      auto partial = reduction_op(values[left], values[left + stride]);
      // the third operand exists only if it falls inside the array
      if (left + stride * 2 < length)
      {
        values[left] = reduction_op(partial, values[left + stride * 2]);
      }
      else
      {
        values[left] = partial;
      }
    }
  }
  return values[0];
}

/***********************************************************************************************************************
 * SIMD Reduction
 **********************************************************************************************************************/

// Never reached at run time. This overload only protects against the instantiation of ThreadReduceSimd with
// arbitrary types and operators: it is selected when enable_generic_simd_reduction<Input, ReductionOp>() is false,
// asserts, and returns input[0] so that the function still has a well-formed return value.
_CCCL_TEMPLATE(typename Input, typename ReductionOp)
_CCCL_REQUIRES((!cub::internal::enable_generic_simd_reduction<Input, ReductionOp>()))
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE auto ThreadReduceSimd(const Input& input, ReductionOp)
  -> ::cuda::std::remove_cvref_t<decltype(input[0])>
{
  assert(false);
  return input[0];
}

// SIMD reduction, selected when enable_generic_simd_reduction<Input, ReductionOp>() is true.
// Steps:
//   1. copy the input into a local array and reinterpret it as an array of SIMD words (simd_type_t<ReductionOp, T>)
//   2. reduce the SIMD array with cub::ThreadReduce() and the SIMD counterpart of ReductionOp
//   3. combine the lanes of the resulting SIMD word (and the left-over last element, if any) into one value
// NOTE: the `reduction_op` argument itself is never invoked; a default-constructed SimdReduceOp is used instead.
_CCCL_TEMPLATE(typename Input, typename ReductionOp)
_CCCL_REQUIRES((cub::internal::enable_generic_simd_reduction<Input, ReductionOp>()))
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE auto ThreadReduceSimd(const Input& input, ReductionOp reduction_op)
  -> ::cuda::std::remove_cvref_t<decltype(input[0])>
{
  using cub::detail::unsafe_bitcast;
  using T                       = ::cuda::std::remove_cvref_t<decltype(input[0])>;
  using SimdReduceOp            = cub::internal::cub_operator_to_simd_operator_t<ReductionOp, T>;
  using SimdType                = simd_type_t<ReductionOp, T>;
  constexpr auto length         = cub::detail::static_size_v<Input>();
  // number of T elements packed in one SIMD word
  constexpr auto simd_ratio     = sizeof(SimdType) / sizeof(T);
  // largest multiple of simd_ratio that does not exceed length (rounds DOWN)
  constexpr auto length_rounded = (length / simd_ratio) * simd_ratio; // TODO: replace with a round-down utility
  using UnpackedType            = ::cuda::std::array<T, simd_ratio>;
  using SimdArray               = ::cuda::std::array<SimdType, length / simd_ratio>;
  static_assert(simd_ratio == 1 || simd_ratio == 2, "Only SIMD size <= 2 is supported");
  // Only the first length_rounded elements are packed; sizeof(local_array) == sizeof(SimdArray) as required by
  // unsafe_bitcast. (presumably UnrolledCopy copies length_rounded elements from input - see thread_load.cuh)
  T local_array[length_rounded];
  UnrolledCopy<length_rounded>(input, local_array);
  auto simd_input         = unsafe_bitcast<SimdArray>(local_array);
  SimdType simd_reduction = cub::ThreadReduce(simd_input, SimdReduceOp{});
  auto unpacked_values    = unsafe_bitcast<UnpackedType>(simd_reduction);
  _CCCL_IF_CONSTEXPR (simd_ratio == 1)
  {
    return unpacked_values[0];
  }
  else // simd_ratio == 2
  {
    // Create a reversed copy of the SIMD reduction result and apply the SIMD operator.
    // This avoids redundant instructions for converting to and from 32-bit register size
    T unpacked_values_rev[] = {unpacked_values[1], unpacked_values[0]};
    auto simd_reduction_rev = unsafe_bitcast<SimdType>(unpacked_values_rev);
    SimdType result         = SimdReduceOp{}(simd_reduction, simd_reduction_rev);
    // repeat the same optimization for the last element
    // (odd length: input[length - 1] was not part of the packed array; the second lane is filled with T{} and is
    //  discarded below because only lane 0 of the result is returned)
    _CCCL_IF_CONSTEXPR (length % simd_ratio == 1)
    {
      T tail[]       = {input[length - 1], T{}};
      auto tail_simd = unsafe_bitcast<SimdType>(tail);
      result         = SimdReduceOp{}(result, tail_simd);
    }
    return unsafe_bitcast<UnpackedType>(result)[0];
  }
  _CCCL_UNREACHABLE(); // nvcc 11.x warning workaround (never reached)
}

} // namespace internal

/***********************************************************************************************************************
 * Reduction Interface/Dispatch (public)
 **********************************************************************************************************************/

// Public entry point: reduces all elements of a fixed-size, subscriptable `input` with `reduction_op` and returns
// the result as AccumT. The algorithm is selected in this order:
//   1. sequential   - operator not in the list of known operators below, or sizeof(ValueT) >= 8
//   2. plus on int / uint32_t - sequential on the SM90 branch, ternary tree otherwise
//   3. SIMD         - when enable_simd_reduction() is true
//   4. ternary tree - when enable_ternary_reduction() is true
//   5. binary tree  - otherwise
// Tree reductions operate on PromT, which is `int` when enable_promotion() is true and AccumT otherwise.
template <typename Input, typename ReductionOp, typename ValueT, typename AccumT>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduce(const Input& input, ReductionOp reduction_op)
{
  static_assert(detail::is_fixed_size_random_access_range_t<Input>::value,
                "Input must support the subscript operator[] and have a compile-time size");
  static_assert(cub::detail::has_binary_call_operator<ReductionOp, ValueT>::value,
                "ReductionOp must have the binary call operator: operator(ValueT, ValueT)");
  using cub::internal::enable_promotion;
  using cub::internal::enable_simd_reduction;
  using cub::internal::enable_ternary_reduction;
  // working type of the tree reductions
  using PromT = ::cuda::std::_If<enable_promotion<Input, ReductionOp, AccumT>(), int, AccumT>;
  // (1) user-defined/unknown operators and 8-byte (or larger) value types: plain sequential reduction
  _CCCL_IF_CONSTEXPR (!cub::detail::is_one_of<
                        ReductionOp,
                        ::cuda::std::plus<>,
                        ::cuda::std::plus<ValueT>,
                        ::cuda::std::multiplies<>,
                        ::cuda::std::multiplies<ValueT>,
                        ::cuda::std::bit_and<>,
                        ::cuda::std::bit_and<ValueT>,
                        ::cuda::std::bit_or<>,
                        ::cuda::std::bit_or<ValueT>,
                        ::cuda::std::bit_xor<>,
                        ::cuda::std::bit_xor<ValueT>,
                        ::cuda::maximum<>,
                        ::cuda::maximum<ValueT>,
                        ::cuda::minimum<>,
                        ::cuda::minimum<ValueT>,
                        cub::internal::SimdMin<ValueT>,
                        cub::internal::SimdMax<ValueT>>()
                      || sizeof(ValueT) >= 8)
  {
    return cub::internal::ThreadReduceSequential<AccumT>(input, reduction_op);
  }
  // (2) sum of 32-bit integers: both branches of NV_IF_TARGET return
  _CCCL_IF_CONSTEXPR (cub::detail::is_one_of<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<ValueT>>()
                      && cub::detail::is_one_of<ValueT, int, ::cuda::std::uint32_t>())
  {
    // clang-format off
    NV_IF_TARGET(NV_PROVIDES_SM_90,
      (return cub::internal::ThreadReduceSequential<AccumT>(input, reduction_op);),
      (return cub::internal::ThreadReduceTernaryTree<PromT>(input, reduction_op);)
    );
    // clang-format on
  }
  // (3)/(4) use a plain `if`: the enable_* predicates contain target dispatch and are only _CCCL_CONSTEXPR_CXX14
  if (enable_simd_reduction<Input, ReductionOp, AccumT>())
  {
    return cub::internal::ThreadReduceSimd(input, reduction_op);
  }
  if (enable_ternary_reduction<Input, ReductionOp, PromT>())
  {
    return cub::internal::ThreadReduceTernaryTree<PromT>(input, reduction_op);
  }
  // (5) default
  return cub::internal::ThreadReduceBinaryTree<PromT>(input, reduction_op);
}

// Reduction seeded with a prefix: builds a temporary AccumT array {prefix, input[0], ..., input[length - 1]} and
// forwards it to the prefix-less ThreadReduce() with both ValueT and AccumT set to AccumT.
template <typename Input,
          typename ReductionOp,
          typename PrefixT,
          typename ValueT = ::cuda::std::remove_cvref_t<decltype(::cuda::std::declval<Input>()[0])>,
          typename AccumT = ::cuda::std::__accumulator_t<ReductionOp, ValueT, PrefixT>>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT
ThreadReduce(const Input& input, ReductionOp reduction_op, PrefixT prefix)
{
  static_assert(detail::is_fixed_size_random_access_range_t<Input>::value,
                "Input must support the subscript operator[] and have a compile-time size");
  static_assert(detail::has_binary_call_operator<ReductionOp, ValueT>::value,
                "ReductionOp must have the binary call operator: operator(ValueT, ValueT)");
  constexpr int length = cub::detail::static_size_v<Input>();
  // slot 0 holds the prefix, slots 1..length hold the input converted to AccumT
  AccumT prefixed_input[length + 1];
  prefixed_input[0] = prefix;
#  pragma unroll
  for (int slot = 1; slot <= length; ++slot)
  {
    prefixed_input[slot] = input[slot - 1];
  }
  return cub::ThreadReduce<decltype(prefixed_input), ReductionOp, AccumT, AccumT>(prefixed_input, reduction_op);
}

/***********************************************************************************************************************
 * Pointer Interfaces with explicit Length (internal use only)
 **********************************************************************************************************************/

namespace internal
{

// Pointer interface (internal use only): views the `Length` elements starting at `input` as a `const T[Length]`
// array and forwards to cub::ThreadReduce().
// NOTE: AccumT is not forwarded to cub::ThreadReduce(); the callee deduces its own accumulator type and the result
// is converted to AccumT on return.
template <int Length, typename T, typename ReductionOp, typename AccumT = ::cuda::std::__accumulator_t<ReductionOp, T>>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduce(const T* input, ReductionOp reduction_op)
{
  static_assert(Length > 0, "Length must be greater than 0");
  static_assert(cub::detail::has_binary_call_operator<ReductionOp, T>::value,
                "ReductionOp must have the binary call operator: operator(V1, V2)");
  using ArrayT              = T[Length];
  const ArrayT& input_array = *reinterpret_cast<const ArrayT*>(input);
  return cub::ThreadReduce(input_array, reduction_op);
}

// Pointer interface with prefix (internal use only), enabled for Length > 0: views the `Length` elements starting at
// `input` as a `const T[Length]` array and forwards to the prefix overload of cub::ThreadReduce().
_CCCL_TEMPLATE(int Length,
               typename T,
               typename ReductionOp,
               typename PrefixT,
               typename AccumT = ::cuda::std::__accumulator_t<ReductionOp, T, PrefixT>)
_CCCL_REQUIRES((Length > 0))
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT
ThreadReduce(const T* input, ReductionOp reduction_op, PrefixT prefix)
{
  static_assert(detail::has_binary_call_operator<ReductionOp, T>::value,
                "ReductionOp must have the binary call operator: operator(V1, V2)");
  using ArrayT              = T[Length];
  const ArrayT& input_array = *reinterpret_cast<const ArrayT*>(input);
  return cub::ThreadReduce(input_array, reduction_op, prefix);
}

// Pointer interface with prefix (internal use only), enabled for Length == 0: there is nothing to reduce, so the
// input pointer and the operator are ignored and the prefix is returned.
// NOTE: the return type is T (not AccumT as in the Length > 0 overload), so `prefix` is converted from PrefixT to T.
_CCCL_TEMPLATE(int Length, typename T, typename ReductionOp, typename PrefixT)
_CCCL_REQUIRES((Length == 0))
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE T ThreadReduce(const T*, ReductionOp, PrefixT prefix)
{
  return prefix;
}

} // namespace internal

#endif // !_CCCL_DOXYGEN_INVOKED

CUB_NAMESPACE_END