cub/device/device_reduce.cuh

/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/detail/choose_offset.cuh>
#include <cub/detail/temporary_storage.cuh>
#include <cub/device/dispatch/dispatch_reduce.cuh>
#include <cub/device/dispatch/dispatch_reduce_by_key.cuh>
#include <cub/device/dispatch/dispatch_reduce_deterministic.cuh>
#include <cub/device/dispatch/dispatch_streaming_reduce.cuh>
#include <cub/util_type.cuh>

#include <thrust/iterator/tabulate_output_iterator.h>

#include <cuda/__execution/determinism.h>
#include <cuda/__execution/require.h>
#include <cuda/__execution/tune.h>
#include <cuda/__memory_resource/get_memory_resource.h>
#include <cuda/__stream/get_stream.h>
#include <cuda/std/__execution/env.h>
#include <cuda/std/limits>

CUB_NAMESPACE_BEGIN

namespace detail
{
namespace reduce
{

struct get_tuning_query_t
{};

template <class Derived>
struct tuning
{
  [[nodiscard]] _CCCL_TRIVIAL_API constexpr auto query(const get_tuning_query_t&) const noexcept -> Derived
  {
    return static_cast<const Derived&>(*this);
  }
};

struct default_tuning : tuning<default_tuning>
{
  template <class AccumT, class Offset, class OpT>
  using fn = policy_hub<AccumT, Offset, OpT>;
};

struct default_rfa_tuning : tuning<default_rfa_tuning>
{
  template <class AccumT, class Offset, class OpT>
  using fn = detail::rfa::policy_hub<AccumT, Offset, OpT>;
};

template <typename ExtremumOutIteratorT, typename IndexOutIteratorT>
struct unzip_and_write_arg_extremum_op
{
  ExtremumOutIteratorT result_out_it;
  IndexOutIteratorT index_out_it;

  template <typename IndexT, typename KeyValuePairT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void operator()(IndexT, KeyValuePairT reduced_result)
  {
    *result_out_it = reduced_result.value;
    *index_out_it  = reduced_result.key;
  }
};
} // namespace reduce

// TODO(gevtushenko): move cudax `device_memory_resource` to `cuda::__device_memory_resource` and use it here
struct device_memory_resource
{
  void* allocate(size_t bytes, size_t /* alignment */)
  {
    void* ptr{nullptr};
    _CCCL_TRY_CUDA_API(::cudaMalloc, "allocate failed to allocate with cudaMalloc", &ptr, bytes);
    return ptr;
  }

  void deallocate(void* ptr, size_t /* bytes */)
  {
    _CCCL_ASSERT_CUDA_API(::cudaFree, "deallocate failed", ptr);
  }

  void* allocate_async(size_t bytes, size_t /* alignment */, ::cuda::stream_ref stream)
  {
    return allocate_async(bytes, stream);
  }

  void* allocate_async(size_t bytes, ::cuda::stream_ref stream)
  {
    void* ptr{nullptr};
    _CCCL_TRY_CUDA_API(
      ::cudaMallocAsync, "allocate_async failed to allocate with cudaMallocAsync", &ptr, bytes, stream.get());
    return ptr;
  }

  void deallocate_async(void* ptr, size_t /* bytes */, const ::cuda::stream_ref stream)
  {
    _CCCL_ASSERT_CUDA_API(::cudaFreeAsync, "deallocate_async failed", ptr, stream.get());
  }
};

} // namespace detail

struct DeviceReduce
{
private:
  // TODO(gevtushenko): dispatch to atomic reduce once merged
  template <typename TuningEnvT,
            typename InputIteratorT,
            typename OutputIteratorT,
            typename ReductionOpT,
            typename T,
            typename NumItemsT,
            ::cuda::execution::determinism::__determinism_t Determinism>
  CUB_RUNTIME_FUNCTION static cudaError_t reduce_impl(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    InputIteratorT d_in,
    OutputIteratorT d_out,
    NumItemsT num_items,
    ReductionOpT reduction_op,
    T init,
    ::cuda::execution::determinism::__determinism_holder_t<Determinism>,
    cudaStream_t stream)
  {
    using offset_t        = detail::choose_offset_t<NumItemsT>;
    using accum_t         = ::cuda::std::__accumulator_t<ReductionOpT, detail::it_value_t<InputIteratorT>, T>;
    using transform_t     = ::cuda::std::identity;
    using reduce_tuning_t = ::cuda::std::execution::
      __query_result_or_t<TuningEnvT, detail::reduce::get_tuning_query_t, detail::reduce::default_tuning>;
    using policy_t = typename reduce_tuning_t::template fn<accum_t, offset_t, ReductionOpT>;
    using dispatch_t =
      DispatchReduce<InputIteratorT, OutputIteratorT, offset_t, ReductionOpT, T, accum_t, transform_t, policy_t>;

    return dispatch_t::Dispatch(
      d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast<offset_t>(num_items), reduction_op, init, stream);
  }

  template <typename TuningEnvT,
            typename InputIteratorT,
            typename OutputIteratorT,
            typename ReductionOpT,
            typename T,
            typename NumItemsT>
  CUB_RUNTIME_FUNCTION static cudaError_t reduce_impl(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    InputIteratorT d_in,
    OutputIteratorT d_out,
    NumItemsT num_items,
    ReductionOpT,
    T init,
    ::cuda::execution::determinism::gpu_to_gpu_t,
    cudaStream_t stream)
  {
    using offset_t = detail::choose_offset_t<NumItemsT>;
    using accum_t  = ::cuda::std::__accumulator_t<ReductionOpT, detail::it_value_t<InputIteratorT>, T>;

    // RFA is only supported for float and double accumulators
    constexpr bool is_float_or_double = detail::is_one_of_v<accum_t, float, double>;
    constexpr bool is_sum             = detail::reduce::is_plus<ReductionOpT>::value;
    constexpr bool is_supported       = is_float_or_double && is_sum;

    static_assert(is_supported, "gpu-to-gpu deterministic reduction supports only float and double sum.");

    if constexpr (is_supported)
    {
      using transform_t     = ::cuda::std::identity;
      using reduce_tuning_t = ::cuda::std::execution::
        __query_result_or_t<TuningEnvT, detail::reduce::get_tuning_query_t, detail::reduce::default_rfa_tuning>;
      using policy_t = typename reduce_tuning_t::template fn<accum_t, offset_t, ReductionOpT>;
      using dispatch_t =
        detail::DispatchReduceDeterministic<InputIteratorT, OutputIteratorT, offset_t, T, transform_t, accum_t, policy_t>;

      return dispatch_t::Dispatch(
        d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast<offset_t>(num_items), init, stream);
    }
    else
    {
      return cudaErrorNotSupported;
    }
  }

public:
  template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T, typename NumItemsT>
  CUB_RUNTIME_FUNCTION static cudaError_t Reduce(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    InputIteratorT d_in,
    OutputIteratorT d_out,
    NumItemsT num_items,
    ReductionOpT reduction_op,
    T init,
    cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::Reduce");

    // Integer type for global offsets, selected by choose_offset_t based on the width of NumItemsT
    using OffsetT = detail::choose_offset_t<NumItemsT>;

    return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT, T>::Dispatch(
      d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast<OffsetT>(num_items), reduction_op, init, stream);
  }
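
  // Usage sketch (illustrative; CustomMin, d_in, d_out, and num_items are placeholder names, not part
  // of this header): a generic reduction with a user-defined binary operator and an explicit initial
  // value. The first call, with d_temp_storage == nullptr, only writes the required allocation size
  // into temp_storage_bytes; the second call performs the reduction.
  //
  //   struct CustomMin
  //   {
  //     template <typename T>
  //     __host__ __device__ T operator()(const T& a, const T& b) const
  //     {
  //       return (b < a) ? b : a;
  //     }
  //   };
  //
  //   int num_items = 7;
  //   int* d_in;   // device array of num_items ints, e.g. [8, 6, 7, 5, 3, 0, 9] (assumed allocated)
  //   int* d_out;  // device array of 1 int
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::Reduce(
  //     d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, CustomMin{}, ::cuda::std::numeric_limits<int>::max());
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::Reduce(
  //     d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, CustomMin{}, ::cuda::std::numeric_limits<int>::max());
  //   // d_out <-- [0]
  //   cudaFree(d_temp_storage);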

  template <typename InputIteratorT,
            typename OutputIteratorT,
            typename ReductionOpT,
            typename T,
            typename NumItemsT,
            typename EnvT = ::cuda::std::execution::env<>>
  CUB_RUNTIME_FUNCTION static cudaError_t Reduce(
    InputIteratorT d_in, OutputIteratorT d_out, NumItemsT num_items, ReductionOpT reduction_op, T init, EnvT env = {})
  {
    _CCCL_NVTX_RANGE_SCOPE("cub::DeviceReduce::Reduce");

    static_assert(!_CUDA_STD_EXEC::__queryable_with<EnvT, _CUDA_EXEC::determinism::__get_determinism_t>,
                  "Determinism should be used inside requires to have an effect.");
    using requirements_t =
      _CUDA_STD_EXEC::__query_result_or_t<EnvT, _CUDA_EXEC::__get_requirements_t, _CUDA_STD_EXEC::env<>>;
    using determinism_t =
      _CUDA_STD_EXEC::__query_result_or_t<requirements_t, //
                                          _CUDA_EXEC::determinism::__get_determinism_t,
                                          _CUDA_EXEC::determinism::run_to_run_t>;

    // Query relevant properties from the environment
    auto stream = _CUDA_STD_EXEC::__query_or(env, ::cuda::get_stream, ::cuda::stream_ref{});
    auto mr     = _CUDA_STD_EXEC::__query_or(env, ::cuda::mr::__get_memory_resource, detail::device_memory_resource{});

    void* d_temp_storage      = nullptr;
    size_t temp_storage_bytes = 0;

    using tuning_t = _CUDA_STD_EXEC::__query_result_or_t<EnvT, _CUDA_EXEC::__get_tuning_t, _CUDA_STD_EXEC::env<>>;

    // Query the required temporary storage size
    cudaError_t error = reduce_impl<tuning_t>(
      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, reduction_op, init, determinism_t{}, stream.get());
    if (error != cudaSuccess)
    {
      return error;
    }

    // TODO(gevtushenko): use uninitialized buffer when it's available
    error = CubDebug(detail::temporary_storage::allocate_async(d_temp_storage, temp_storage_bytes, mr, stream));
    if (error != cudaSuccess)
    {
      return error;
    }

    // Run the algorithm
    error = reduce_impl<tuning_t>(
      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, reduction_op, init, determinism_t{}, stream.get());

    // Try to deallocate regardless of the error to avoid memory leaks
    cudaError_t deallocate_error =
      CubDebug(detail::temporary_storage::deallocate_async(d_temp_storage, temp_storage_bytes, mr, stream));

    if (error != cudaSuccess)
    {
      // Reduction error takes precedence over deallocation error since it happens first
      return error;
    }

    return deallocate_error;
  }
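
  // Usage sketch (illustrative; d_in, d_out, and num_items are placeholders): unlike the overload above
  // that takes d_temp_storage, this environment-based overload sizes, allocates, and frees the temporary
  // storage itself, so a single call suffices. Stream, memory resource, tuning, and determinism
  // requirements are queried from the environment argument; with the default environment, execution uses
  // the default stream and the cudaMallocAsync-based detail::device_memory_resource defined above.
  //
  //   float* d_in;   // device array of num_items floats (assumed allocated and initialized)
  //   float* d_out;  // device array of 1 float
  //   int num_items = 1024;
  //   cudaError_t error = cub::DeviceReduce::Reduce(d_in, d_out, num_items, ::cuda::std::plus<>{}, 0.0f);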

  template <typename InputIteratorT, typename OutputIteratorT, typename NumItemsT>
  CUB_RUNTIME_FUNCTION static cudaError_t
  Sum(void* d_temp_storage,
      size_t& temp_storage_bytes,
      InputIteratorT d_in,
      OutputIteratorT d_out,
      NumItemsT num_items,
      cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::Sum");

    // Integer type for global offsets, selected by choose_offset_t based on the width of NumItemsT
    using OffsetT = detail::choose_offset_t<NumItemsT>;

    // The output value type
    using OutputT = cub::detail::non_void_value_t<OutputIteratorT, cub::detail::it_value_t<InputIteratorT>>;

    using InitT = OutputT;

    return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ::cuda::std::plus<>, InitT>::Dispatch(
      d_temp_storage,
      temp_storage_bytes,
      d_in,
      d_out,
      static_cast<OffsetT>(num_items),
      ::cuda::std::plus<>{},
      InitT{}, // zero-initialize
      stream);
  }
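
  // Usage sketch (illustrative placeholders): sums an int array with the same two-phase pattern as
  // Reduce above; the accumulator is zero-initialized, so no init value is passed.
  //
  //   int num_items = 7;
  //   int* d_in;   // device array, e.g. [8, 6, 7, 5, 3, 0, 9] (assumed allocated and initialized)
  //   int* d_out;  // device array of 1 int
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  //   // d_out <-- [38]
  //   cudaFree(d_temp_storage);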

  template <typename InputIteratorT, typename OutputIteratorT, typename NumItemsT>
  CUB_RUNTIME_FUNCTION static cudaError_t
  Min(void* d_temp_storage,
      size_t& temp_storage_bytes,
      InputIteratorT d_in,
      OutputIteratorT d_out,
      NumItemsT num_items,
      cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::Min");

    // Integer type for global offsets, selected by choose_offset_t based on the width of NumItemsT
    using OffsetT = detail::choose_offset_t<NumItemsT>;

    // The input value type
    using InputT = cub::detail::it_value_t<InputIteratorT>;

    using InitT = InputT;

    return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ::cuda::minimum<>, InitT>::Dispatch(
      d_temp_storage,
      temp_storage_bytes,
      d_in,
      d_out,
      static_cast<OffsetT>(num_items),
      ::cuda::minimum<>{},
      ::cuda::std::numeric_limits<InitT>::max(),
      stream);
  }
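
  // Usage sketch (illustrative placeholders): minimum of a float array. As the code above shows, the
  // reduction is seeded with ::cuda::std::numeric_limits<InputT>::max(), so an empty input produces
  // that sentinel value rather than an error.
  //
  //   int num_items = 1024;
  //   float* d_in;   // device array of num_items floats (assumed allocated and initialized)
  //   float* d_out;  // device array of 1 float
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  //   cudaFree(d_temp_storage);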

  template <typename InputIteratorT, typename ExtremumOutIteratorT, typename IndexOutIteratorT>
  CUB_RUNTIME_FUNCTION static cudaError_t ArgMin(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    InputIteratorT d_in,
    ExtremumOutIteratorT d_min_out,
    IndexOutIteratorT d_index_out,
    ::cuda::std::int64_t num_items,
    cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::ArgMin");

    // The input type
    using InputValueT = cub::detail::it_value_t<InputIteratorT>;

    // Offset type used within the kernel and to index within one partition
    using PerPartitionOffsetT = int;

    // Offset type used to index within the total input in the range [d_in, d_in + num_items)
    using GlobalOffsetT = ::cuda::std::int64_t;

    // The value type used for the extremum
    using OutputExtremumT = detail::non_void_value_t<ExtremumOutIteratorT, InputValueT>;
    using InitT           = OutputExtremumT;

    // Reduction operation
    using ReduceOpT = cub::ArgMin;

    // Initial value
    OutputExtremumT initial_value{::cuda::std::numeric_limits<InputValueT>::max()};

    // Tabulate output iterator that unzips the result and writes it to the user-provided output iterators
    auto out_it = THRUST_NS_QUALIFIER::make_tabulate_output_iterator(
      detail::reduce::unzip_and_write_arg_extremum_op<ExtremumOutIteratorT, IndexOutIteratorT>{d_min_out, d_index_out});

    return detail::reduce::dispatch_streaming_arg_reduce_t<
      InputIteratorT,
      decltype(out_it),
      PerPartitionOffsetT,
      GlobalOffsetT,
      ReduceOpT,
      InitT>::Dispatch(d_temp_storage,
                       temp_storage_bytes,
                       d_in,
                       out_it,
                       static_cast<GlobalOffsetT>(num_items),
                       ReduceOpT{},
                       initial_value,
                       stream);
  }
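
  // Usage sketch (illustrative placeholders): argmin with the two-iterator interface. The smallest value
  // is written through d_min_out and its global index, an int64 offset per GlobalOffsetT above, through
  // d_index_out.
  //
  //   ::cuda::std::int64_t num_items = 1024;
  //   float* d_in;                        // device array of num_items floats (assumed allocated)
  //   float* d_min_out;                   // device array of 1 float: receives the smallest value
  //   ::cuda::std::int64_t* d_index_out;  // device array of 1 offset: receives the index of that value
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_min_out, d_index_out, num_items);
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_min_out, d_index_out, num_items);
  //   cudaFree(d_temp_storage);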

  template <typename InputIteratorT, typename OutputIteratorT>
  CCCL_DEPRECATED_BECAUSE("CUB has superseded this interface in favor of the ArgMin interface that takes two separate "
                          "iterators: one iterator to which the extremum is written and another iterator to which the "
                          "index of the found extremum is written. ")
  CUB_RUNTIME_FUNCTION static cudaError_t
    ArgMin(void* d_temp_storage,
           size_t& temp_storage_bytes,
           InputIteratorT d_in,
           OutputIteratorT d_out,
           int num_items,
           cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::ArgMin");

    // Signed integer type for global offsets
    using OffsetT = int;

    // The input type
    using InputValueT = cub::detail::it_value_t<InputIteratorT>;

    // The output tuple type
    using OutputTupleT = cub::detail::non_void_value_t<OutputIteratorT, KeyValuePair<OffsetT, InputValueT>>;

    using AccumT = OutputTupleT;

    using InitT = detail::reduce::empty_problem_init_t<AccumT>;

    // The output value type
    using OutputValueT = typename OutputTupleT::Value;

    // Wrapped input iterator to produce index-value <OffsetT, InputT> tuples
    using ArgIndexInputIteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;

    ArgIndexInputIteratorT d_indexed_in(d_in);

    // Initial value
    InitT initial_value{AccumT(1, ::cuda::std::numeric_limits<InputValueT>::max())};

    return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMin, InitT, AccumT>::Dispatch(
      d_temp_storage, temp_storage_bytes, d_indexed_in, d_out, num_items, cub::ArgMin(), initial_value, stream);
  }

  template <typename InputIteratorT, typename OutputIteratorT, typename NumItemsT>
  CUB_RUNTIME_FUNCTION static cudaError_t
  Max(void* d_temp_storage,
      size_t& temp_storage_bytes,
      InputIteratorT d_in,
      OutputIteratorT d_out,
      NumItemsT num_items,
      cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::Max");

    // Integer type for global offsets, selected by choose_offset_t based on the width of NumItemsT
    using OffsetT = detail::choose_offset_t<NumItemsT>;

    // The input value type
    using InputT = cub::detail::it_value_t<InputIteratorT>;

    using InitT = InputT;

    return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ::cuda::maximum<>, InitT>::Dispatch(
      d_temp_storage,
      temp_storage_bytes,
      d_in,
      d_out,
      static_cast<OffsetT>(num_items),
      ::cuda::maximum<>{},
      ::cuda::std::numeric_limits<InitT>::lowest(),
      stream);
  }
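
  // Usage sketch (illustrative placeholders): maximum of a float array; the reduction is seeded with
  // ::cuda::std::numeric_limits<InputT>::lowest(), mirroring Min above.
  //
  //   int num_items = 1024;
  //   float* d_in;   // device array of num_items floats (assumed allocated and initialized)
  //   float* d_out;  // device array of 1 float
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  //   cudaFree(d_temp_storage);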

  template <typename InputIteratorT, typename ExtremumOutIteratorT, typename IndexOutIteratorT>
  CUB_RUNTIME_FUNCTION static cudaError_t ArgMax(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    InputIteratorT d_in,
    ExtremumOutIteratorT d_max_out,
    IndexOutIteratorT d_index_out,
    ::cuda::std::int64_t num_items,
    cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::ArgMax");

    // The input type
    using InputValueT = cub::detail::it_value_t<InputIteratorT>;

    // Offset type used within the kernel and to index within one partition
    using PerPartitionOffsetT = int;

    // Offset type used to index within the total input in the range [d_in, d_in + num_items)
    using GlobalOffsetT = ::cuda::std::int64_t;

    // The value type used for the extremum
    using OutputExtremumT = detail::non_void_value_t<ExtremumOutIteratorT, InputValueT>;
    using InitT           = OutputExtremumT;

    // Reduction operation
    using ReduceOpT = cub::ArgMax;

    // Initial value
    OutputExtremumT initial_value{::cuda::std::numeric_limits<InputValueT>::lowest()};

    // Tabulate output iterator that unzips the result and writes it to the user-provided output iterators
    auto out_it = THRUST_NS_QUALIFIER::make_tabulate_output_iterator(
      detail::reduce::unzip_and_write_arg_extremum_op<ExtremumOutIteratorT, IndexOutIteratorT>{d_max_out, d_index_out});

    return detail::reduce::dispatch_streaming_arg_reduce_t<
      InputIteratorT,
      decltype(out_it),
      PerPartitionOffsetT,
      GlobalOffsetT,
      ReduceOpT,
      InitT>::Dispatch(d_temp_storage,
                       temp_storage_bytes,
                       d_in,
                       out_it,
                       static_cast<GlobalOffsetT>(num_items),
                       ReduceOpT{},
                       initial_value,
                       stream);
  }
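
  // Usage sketch (illustrative placeholders): mirrors the ArgMin example above; the largest value is
  // written through d_max_out and its int64 global index through d_index_out.
  //
  //   ::cuda::std::int64_t num_items = 1024;
  //   float* d_in;                        // device array of num_items floats (assumed allocated)
  //   float* d_max_out;                   // device array of 1 float: receives the largest value
  //   ::cuda::std::int64_t* d_index_out;  // device array of 1 offset: receives the index of that value
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_max_out, d_index_out, num_items);
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_max_out, d_index_out, num_items);
  //   cudaFree(d_temp_storage);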

  template <typename InputIteratorT, typename OutputIteratorT>
  CCCL_DEPRECATED_BECAUSE("CUB has superseded this interface in favor of the ArgMax interface that takes two separate "
                          "iterators: one iterator to which the extremum is written and another iterator to which the "
                          "index of the found extremum is written. ")
  CUB_RUNTIME_FUNCTION static cudaError_t
    ArgMax(void* d_temp_storage,
           size_t& temp_storage_bytes,
           InputIteratorT d_in,
           OutputIteratorT d_out,
           int num_items,
           cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::ArgMax");

    // Signed integer type for global offsets
    using OffsetT = int;

    // The input type
    using InputValueT = cub::detail::it_value_t<InputIteratorT>;

    // The output tuple type
    using OutputTupleT = cub::detail::non_void_value_t<OutputIteratorT, KeyValuePair<OffsetT, InputValueT>>;

    using AccumT = OutputTupleT;

    // The output value type
    using OutputValueT = typename OutputTupleT::Value;

    using InitT = detail::reduce::empty_problem_init_t<AccumT>;

    // Wrapped input iterator to produce index-value <OffsetT, InputT> tuples
    using ArgIndexInputIteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;

    ArgIndexInputIteratorT d_indexed_in(d_in);

    // Initial value
    InitT initial_value{AccumT(1, ::cuda::std::numeric_limits<InputValueT>::lowest())};

    return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMax, InitT, AccumT>::Dispatch(
      d_temp_storage, temp_storage_bytes, d_indexed_in, d_out, num_items, cub::ArgMax(), initial_value, stream);
  }

  template <typename InputIteratorT,
            typename OutputIteratorT,
            typename ReductionOpT,
            typename TransformOpT,
            typename T,
            typename NumItemsT>
  CUB_RUNTIME_FUNCTION static cudaError_t TransformReduce(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    InputIteratorT d_in,
    OutputIteratorT d_out,
    NumItemsT num_items,
    ReductionOpT reduction_op,
    TransformOpT transform_op,
    T init,
    cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::TransformReduce");

    using OffsetT = detail::choose_offset_t<NumItemsT>;

    return DispatchTransformReduce<InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT, TransformOpT, T>::Dispatch(
      d_temp_storage,
      temp_storage_bytes,
      d_in,
      d_out,
      static_cast<OffsetT>(num_items),
      reduction_op,
      init,
      stream,
      transform_op);
  }
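
  // Usage sketch (illustrative; the Square functor is a placeholder, not provided by CUB): computes a
  // sum of squares. Note the argument order above: the reduction operator precedes the transform
  // operator, and the transform is applied to each input element before it enters the reduction.
  //
  //   struct Square
  //   {
  //     __host__ __device__ float operator()(float x) const { return x * x; }
  //   };
  //
  //   int num_items = 1024;
  //   float* d_in;   // device array of num_items floats (assumed allocated and initialized)
  //   float* d_out;  // device array of 1 float
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::TransformReduce(
  //     d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, ::cuda::std::plus<>{}, Square{}, 0.0f);
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::TransformReduce(
  //     d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, ::cuda::std::plus<>{}, Square{}, 0.0f);
  //   cudaFree(d_temp_storage);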

  template <typename KeysInputIteratorT,
            typename UniqueOutputIteratorT,
            typename ValuesInputIteratorT,
            typename AggregatesOutputIteratorT,
            typename NumRunsOutputIteratorT,
            typename ReductionOpT,
            typename NumItemsT>
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t ReduceByKey(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    KeysInputIteratorT d_keys_in,
    UniqueOutputIteratorT d_unique_out,
    ValuesInputIteratorT d_values_in,
    AggregatesOutputIteratorT d_aggregates_out,
    NumRunsOutputIteratorT d_num_runs_out,
    ReductionOpT reduction_op,
    NumItemsT num_items,
    cudaStream_t stream = 0)
  {
    _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceReduce::ReduceByKey");

    // Integer type for global offsets, selected by choose_offset_t based on the width of NumItemsT
    using OffsetT = detail::choose_offset_t<NumItemsT>;

    // Default == operator
    using EqualityOp = ::cuda::std::equal_to<>;

    return DispatchReduceByKey<
      KeysInputIteratorT,
      UniqueOutputIteratorT,
      ValuesInputIteratorT,
      AggregatesOutputIteratorT,
      NumRunsOutputIteratorT,
      EqualityOp,
      ReductionOpT,
      OffsetT>::Dispatch(d_temp_storage,
                         temp_storage_bytes,
                         d_keys_in,
                         d_unique_out,
                         d_values_in,
                         d_aggregates_out,
                         d_num_runs_out,
                         EqualityOp(),
                         reduction_op,
                         static_cast<OffsetT>(num_items),
                         stream);
  }
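
  // Usage sketch (illustrative placeholders): segmented sum over runs of equal consecutive keys. For
  // keys [0, 0, 1, 1, 1, 2] and values [1, 2, 3, 4, 5, 6], the call writes unique keys [0, 1, 2],
  // aggregates [3, 12, 6], and a run count of 3. Note that the reduction operator precedes num_items
  // in the argument list.
  //
  //   int num_items = 6;
  //   int* d_keys_in;         // device array of num_items keys (assumed allocated and initialized)
  //   int* d_values_in;       // device array of num_items values
  //   int* d_unique_out;      // device array of up to num_items unique keys
  //   int* d_aggregates_out;  // device array of up to num_items per-run aggregates
  //   int* d_num_runs_out;    // device array of 1 int: number of runs found
  //   void* d_temp_storage      = nullptr;
  //   size_t temp_storage_bytes = 0;
  //   cub::DeviceReduce::ReduceByKey(
  //     d_temp_storage, temp_storage_bytes, d_keys_in, d_unique_out, d_values_in, d_aggregates_out,
  //     d_num_runs_out, ::cuda::std::plus<>{}, num_items);
  //   cudaMalloc(&d_temp_storage, temp_storage_bytes);
  //   cub::DeviceReduce::ReduceByKey(
  //     d_temp_storage, temp_storage_bytes, d_keys_in, d_unique_out, d_values_in, d_aggregates_out,
  //     d_num_runs_out, ::cuda::std::plus<>{}, num_items);
  //   cudaFree(d_temp_storage);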
};

CUB_NAMESPACE_END