/home/runner/work/cccl/cccl/cub/cub/util_type.cuh

File members: /home/runner/work/cccl/cccl/cub/cub/util_type.cuh

/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/detail/uninitialized_copy.cuh>

#include <cuda/std/cstdint>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

#if defined(_CCCL_HAS_NVBF16)
#  if !defined(_CCCL_CUDACC_BELOW_11_8)
// cuda_fp8.h resets default for C4127, so we have to guard the inclusion
_CCCL_DIAG_PUSH
#    include <cuda_fp8.h>
_CCCL_DIAG_POP
#  endif // !_CCCL_CUDACC_BELOW_11_8
#endif // _CCCL_HAS_NVBF16

#if !defined(_CCCL_COMPILER_NVRTC)
#  include <iterator>
#else
#  include <cuda/std/iterator>
#endif

CUB_NAMESPACE_BEGIN

// Defines CUB_IS_INT128_ENABLED to 1 when 128-bit integers may be used, unless
// the macro has already been defined by the user:
//   * NVRTC: only when __CUDACC_RTC_INT128__ is defined.
//   * Otherwise: CUDA toolkit 11.5 or newer combined with a GCC, Clang, ICC or
//     NVHPC host compiler.
// In every other configuration the macro is left undefined.
#ifndef CUB_IS_INT128_ENABLED
#  if defined(__CUDACC_RTC__)
#    if defined(__CUDACC_RTC_INT128__)
#      define CUB_IS_INT128_ENABLED 1
#    endif // defined(__CUDACC_RTC_INT128__)
#  else // !defined(__CUDACC_RTC__)
#    if _CCCL_CUDACC_VER >= 1105000
#      if defined(_CCCL_COMPILER_GCC) || defined(_CCCL_COMPILER_CLANG) || defined(_CCCL_COMPILER_ICC) \
        || defined(_CCCL_COMPILER_NVHPC)
#        define CUB_IS_INT128_ENABLED 1
#      endif // GCC || CLANG || ICC || NVHPC
#    endif // CTK >= 11.5
#  endif // !defined(__CUDACC_RTC__)
#endif // !defined(CUB_IS_INT128_ENABLED)

/******************************************************************************
 * Conditional types
 ******************************************************************************/

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
namespace detail
{

// Shorthand for `::cuda::std::conditional<Test, T1, T2>::type`: T1 when Test is
// true, T2 otherwise.
template <bool Test, class T1, class T2>
using conditional_t = typename ::cuda::std::conditional<Test, T1, T2>::type;

// The `value_type` reported by the iterator traits of `Iterator`. The traits
// are taken from `std::iterator_traits`, except when compiling with NVRTC,
// where `::cuda::std::iterator_traits` is used instead.
template <typename Iterator>
using value_t =
#  if !defined(_CCCL_COMPILER_NVRTC)
  typename std::iterator_traits<Iterator>::value_type;
#  else // defined(_CCCL_COMPILER_NVRTC)
  typename ::cuda::std::iterator_traits<Iterator>::value_type;
#  endif // defined(_CCCL_COMPILER_NVRTC)

// Implementation of `non_void_value_t`.
//
// The defaulted third parameter is true when `It`, after removing one pointer
// level and the cv-qualifiers of the pointee, is `void` (e.g. `void*` or
// `const void*`). This primary template is the one selected in that case and
// yields `FallbackT` without ever querying the iterator traits of `It`.
template <typename It,
          typename FallbackT,
          bool = ::cuda::std::
            is_same<typename ::cuda::std::remove_cv<typename ::cuda::std::remove_pointer<It>::type>::type, void>::value>
struct non_void_value_impl
{
  using type = FallbackT;
};

// Selected when `It` is not a (cv-qualified) void pointer: yields `value_t<It>`,
// unless that value type is itself `void`, in which case `FallbackT` is used.
template <typename It, typename FallbackT>
struct non_void_value_impl<It, FallbackT, false>
{
  using type =
    typename ::cuda::std::conditional<::cuda::std::is_same<value_t<It>, void>::value, FallbackT, value_t<It>>::type;
};

// The value type of `It` as computed by `non_void_value_impl`; `FallbackT` is
// substituted whenever that value type would be `void`.
template <typename It, typename FallbackT>
using non_void_value_t = typename non_void_value_impl<It, FallbackT>::type;
} // namespace detail

/******************************************************************************
 * Static math
 ******************************************************************************/

/**
 * \brief Compile-time base-2 logarithm of \p N, rounded up
 *        (e.g. Log2<8>::VALUE == 3 and Log2<9>::VALUE == 4).
 *
 * \p SHIFTED and \p STEPS carry the recursion state (the value still to be
 * shifted out and the number of shifts performed so far); callers are expected
 * to supply \p N only.
 */
template <int N, int SHIFTED = N, int STEPS = 0>
struct Log2
{
  // Inductive case: shift one more bit out and count the step
  enum
  {
    VALUE = Log2<N, (SHIFTED >> 1), STEPS + 1>::VALUE
  };
};

#  ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

// Base case: all bits of N have been shifted out after STEPS shifts. When
// 2^(STEPS - 1) is still smaller than N the result rounds up to STEPS,
// otherwise (N is an exact power of two) it is STEPS - 1.
template <int N, int STEPS>
struct Log2<N, 0, STEPS>
{
  enum
  {
    VALUE = STEPS - (((1 << (STEPS - 1)) < N) ? 0 : 1)
  };
};

#  endif // DOXYGEN_SHOULD_SKIP_THIS

/**
 * \brief Compile-time test for whether \p N has at most one bit set.
 *
 * VALUE is 1 for powers of two (and for N == 0, which also passes the bit
 * test) and 0 otherwise.
 */
template <int N>
struct PowerOfTwo
{
  enum
  {
    // N & (N - 1) clears the lowest set bit; nothing remains only if at most one bit was set
    VALUE = !(N & (N - 1))
  };
};

#endif // DOXYGEN_SHOULD_SKIP_THIS

/******************************************************************************
 * Marker types
 ******************************************************************************/

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

/**
 * \brief Empty marker type.
 *
 * A NullType can be explicitly constructed from, and assigned from, a value of
 * any type; the value is ignored. All NullType objects compare equal.
 */
struct NullType
{
  using value_type = NullType;

  NullType() = default;

  /// Explicit construction from an arbitrary value, which is discarded
  template <typename T>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE explicit NullType(const T&)
  {}

  /// Assignment from an arbitrary value is a no-op
  template <typename T>
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE NullType& operator=(const T&)
  {
    return *this;
  }

  /// Always true: NullType carries no state
  friend _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator==(const NullType&, const NullType&)
  {
    return true;
  }

  /// Always false; expressed as the negation of operator==
  friend _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator!=(const NullType& lhs, const NullType& rhs)
  {
    return !(lhs == rhs);
  }
};

/**
 * \brief Wraps the integer constant \p A in a distinct type, so that each
 *        value of \p A names a different type. The constant is exposed as VALUE.
 */
template <int A>
struct Int2Type
{
  enum
  {
    VALUE = A
  };
};

/**
 * \brief Wraps an iterator whose pointee is read only when the wrapper is
 *        converted to \p T.
 *
 * \tparam T     Type produced by the conversion
 * \tparam IterT Iterator (by default `T*`) that is dereferenced on conversion
 */
template <typename T, typename IterT = T*>
struct FutureValue
{
private:
  // Stored by value; dereferenced on every conversion to T
  IterT m_iter;

public:
  using value_type    = T;
  using iterator_type = IterT;

  /// Stores a copy of \p iter; the iterator is not dereferenced here
  explicit _CCCL_HOST_DEVICE _CCCL_FORCEINLINE FutureValue(IterT iter)
      : m_iter(iter)
  {}

  /// Reads the value through the stored iterator
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE operator T()
  {
    return *m_iter;
  }
};

namespace detail
{

/**
 * \brief Holds either an immediate value of type \p T or a
 *        FutureValue<T, IterT>, and converts to \p T in both cases.
 *
 * The two alternatives share storage in an anonymous union; `m_is_future`
 * records which one is active.
 *
 * \tparam T     Value type
 * \tparam IterT Iterator type of the wrapped FutureValue (by default `T*`)
 */
template <typename T, typename IterT = T*>
struct InputValue
{
  using value_type    = T;
  using iterator_type = IterT;

  /// Returns the immediate value, or the value obtained by converting the stored FutureValue
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE operator T()
  {
    if (m_is_future)
    {
      return m_future_value;
    }
    return m_immediate_value;
  }

  /// Wraps a value that is already available
  explicit _CCCL_HOST_DEVICE _CCCL_FORCEINLINE InputValue(T immediate_value)
      : m_is_future(false)
      , m_immediate_value(immediate_value)
  {}

  /// Wraps a value that is read through \p future_value on conversion
  explicit _CCCL_HOST_DEVICE _CCCL_FORCEINLINE InputValue(FutureValue<T, IterT> future_value)
      : m_is_future(true)
      , m_future_value(future_value)
  {}

  /// Copies whichever alternative is active in \p other
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE InputValue(const InputValue& other)
      : m_is_future(other.m_is_future)
  {
    // Neither union member has been constructed at this point, so both branches
    // go through uninitialized_copy_single rather than plain assignment to a
    // member whose lifetime has not started.
    // NOTE(review): relies on uninitialized_copy_single being valid for
    // uninitialized destinations, as its name and the existing use below suggest.
    if (m_is_future)
    {
      detail::uninitialized_copy_single(&m_future_value, other.m_future_value);
    }
    else
    {
      detail::uninitialized_copy_single(&m_immediate_value, other.m_immediate_value);
    }
  }

private:
  // Discriminator for the union below
  bool m_is_future;
  union
  {
    FutureValue<T, IterT> m_future_value;
    T m_immediate_value;
  };
};

} // namespace detail

/******************************************************************************
 * Size and alignment
 ******************************************************************************/

/**
 * \brief Computes the alignment of \p T in bytes and exposes \p T as `Type`.
 *
 * ALIGN_BYTES is measured structurally: it is the number of bytes by which a
 * struct holding a T followed by a single char exceeds sizeof(T).
 */
template <typename T>
struct AlignBytes
{
  // Helper layout: the trailing char forces padding up to the next multiple of T's alignment
  struct Pad
  {
    T val;
    char byte;
  };

  enum
  {
    ALIGN_BYTES = sizeof(Pad) - sizeof(T)
  };

  using Type = T;
};

// Specializations where host C++ compilers (e.g., 32-bit Windows) may disagree
// with device C++ compilers (EDG) on types passed as template parameters through
// kernel functions

// Defines the AlignBytes specialization for type `t` with ALIGN_BYTES fixed to
// `b`; the nested `Type` is `t` declared with `__align__(b)`.
#  define __CUB_ALIGN_BYTES(t, b)                                                                  \
    template <>                                                                                    \
    struct AlignBytes<t>                                                                           \
    {                                                                                              \
      enum                                                                                         \
      {                                                                                            \
        ALIGN_BYTES = b                                                                            \
      };                                                                                           \
      typedef __align__(b) t Type;                                                                 \
      /* TODO(bgruber): rewriting the above to using Type __align__(b) = t; does not compile :S */ \
    };

// Fixed alignments for the types listed below. long2/ulong2 are given 8 bytes
// when _WIN32 is defined and 16 bytes otherwise; everything else is 8 or 16
// bytes regardless of platform.
__CUB_ALIGN_BYTES(short4, 8)
__CUB_ALIGN_BYTES(ushort4, 8)
__CUB_ALIGN_BYTES(int2, 8)
__CUB_ALIGN_BYTES(uint2, 8)
__CUB_ALIGN_BYTES(long long, 8)
__CUB_ALIGN_BYTES(unsigned long long, 8)
__CUB_ALIGN_BYTES(float2, 8)
__CUB_ALIGN_BYTES(double, 8)
#  ifdef _WIN32
__CUB_ALIGN_BYTES(long2, 8)
__CUB_ALIGN_BYTES(ulong2, 8)
#  else
__CUB_ALIGN_BYTES(long2, 16)
__CUB_ALIGN_BYTES(ulong2, 16)
#  endif
__CUB_ALIGN_BYTES(int4, 16)
__CUB_ALIGN_BYTES(uint4, 16)
__CUB_ALIGN_BYTES(float4, 16)
__CUB_ALIGN_BYTES(long4, 16)
__CUB_ALIGN_BYTES(ulong4, 16)
__CUB_ALIGN_BYTES(longlong2, 16)
__CUB_ALIGN_BYTES(ulonglong2, 16)
__CUB_ALIGN_BYTES(double2, 16)
__CUB_ALIGN_BYTES(longlong4, 16)
__CUB_ALIGN_BYTES(ulonglong4, 16)
__CUB_ALIGN_BYTES(double4, 16)

// clang-format off
// cv-qualified types inherit ALIGN_BYTES and Type from the unqualified type
template <typename T> struct AlignBytes<volatile T> : AlignBytes<T> {};
template <typename T> struct AlignBytes<const T> : AlignBytes<T> {};
template <typename T> struct AlignBytes<const volatile T> : AlignBytes<T> {};
// clang-format on

/**
 * \brief Selects unsigned "word" types whose size and alignment evenly divide
 *        those of \p T.
 *
 * Each alias picks the first candidate, from widest to narrowest, for which
 * IsMultiple<Candidate>::IS_MULTIPLE holds.
 */
template <typename T>
struct UnitWord
{
  enum
  {
    ALIGN_BYTES = AlignBytes<T>::ALIGN_BYTES
  };

  // IS_MULTIPLE is true when sizeof(T) is a multiple of sizeof(Unit) and the
  // alignment of T is a multiple of the alignment of Unit
  template <typename Unit>
  struct IsMultiple
  {
    enum
    {
      UNIT_ALIGN_BYTES = AlignBytes<Unit>::ALIGN_BYTES,
      IS_MULTIPLE      = (sizeof(T) % sizeof(Unit) == 0) && (int(ALIGN_BYTES) % int(UNIT_ALIGN_BYTES) == 0)
    };
  };

  // unsigned int if T is a multiple of int, else unsigned short if a multiple of short, else unsigned char
  using ShuffleWord = cub::detail::conditional_t<
    IsMultiple<int>::IS_MULTIPLE,
    unsigned int,
    cub::detail::conditional_t<IsMultiple<short>::IS_MULTIPLE, unsigned short, unsigned char>>;

  // unsigned long long if T is a multiple of long long, else ShuffleWord
  using VolatileWord = cub::detail::conditional_t<IsMultiple<long long>::IS_MULTIPLE, unsigned long long, ShuffleWord>;

  // ulonglong2 if T is a multiple of longlong2, else VolatileWord
  using DeviceWord = cub::detail::conditional_t<IsMultiple<longlong2>::IS_MULTIPLE, ulonglong2, VolatileWord>;

  // uint4 if T is a multiple of int4, else uint2 if a multiple of int2, else ShuffleWord
  using TextureWord =
    cub::detail::conditional_t<IsMultiple<int4>::IS_MULTIPLE,
                               uint4,
                               cub::detail::conditional_t<IsMultiple<int2>::IS_MULTIPLE, uint2, ShuffleWord>>;
};

// float2 specialization workaround (for SM10-SM13): the word types are spelled
// out by hand instead of being derived from size and alignment. Unlike the
// primary template, this specialization defines neither ALIGN_BYTES nor IsMultiple.
template <>
struct UnitWord<float2>
{
  using ShuffleWord  = int;
  using VolatileWord = unsigned long long;
  using DeviceWord   = unsigned long long;
  using TextureWord  = float2;
};

// float4 specialization workaround (for SM10-SM13): the word types are spelled
// out by hand instead of being derived from size and alignment. Unlike the
// primary template, this specialization defines neither ALIGN_BYTES nor IsMultiple.
template <>
struct UnitWord<float4>
{
  using ShuffleWord  = int;
  using VolatileWord = unsigned long long;
  using DeviceWord   = ulonglong2;
  using TextureWord  = float4;
};

// char2 specialization workaround (for SM10-SM13): every word type is fixed to
// unsigned short. Unlike the primary template, this specialization defines
// neither ALIGN_BYTES nor IsMultiple.
template <>
struct UnitWord<char2>
{
  using ShuffleWord  = unsigned short;
  using VolatileWord = unsigned short;
  using DeviceWord   = unsigned short;
  using TextureWord  = unsigned short;
};

// clang-format off
// cv-qualified types use the same word types as the unqualified type
template <typename T> struct UnitWord<volatile T> : UnitWord<T> {};
template <typename T> struct UnitWord<const T> : UnitWord<T> {};
template <typename T> struct UnitWord<const volatile T> : UnitWord<T> {};
// clang-format on

/******************************************************************************
 * Vector type inference utilities.
 ******************************************************************************/

/**
 * \brief Vector of \p vec_elements components of type \p T.
 *
 * The primary template is only instantiated for element counts that have no
 * specialization; it then fails to compile with the message below.
 */
template <typename T, int vec_elements>
struct CubVector
{
  // !sizeof(T) is always false but depends on T, so the assertion only fires on instantiation
  static_assert(!sizeof(T), "CubVector can only have 1-4 elements");
};

// Largest number of components for which CubVector is specialized
enum
{
  MAX_VEC_ELEMENTS = 4,
};

/// Generic one-component vector: a plain struct with the single member x
template <typename T>
struct CubVector<T, 1>
{
  using BaseType = T;
  using Type     = CubVector; // injected class name, i.e. CubVector<T, 1>

  T x;
};

/// Generic two-component vector: a plain struct with the members x and y
template <typename T>
struct CubVector<T, 2>
{
  using BaseType = T;
  using Type     = CubVector; // injected class name, i.e. CubVector<T, 2>

  T x;
  T y;
};

/// Generic three-component vector: a plain struct with the members x, y and z
template <typename T>
struct CubVector<T, 3>
{
  using BaseType = T;
  using Type     = CubVector; // injected class name, i.e. CubVector<T, 3>

  T x;
  T y;
  T z;
};

/// Generic four-component vector: a plain struct with the members x, y, z and w
template <typename T>
struct CubVector<T, 4>
{
  using BaseType = T;
  using Type     = CubVector; // injected class name, i.e. CubVector<T, 4>

  T x;
  T y;
  T z;
  T w;
};

// Defines the CubVector<base_type, 1> ... CubVector<base_type, 4> specializations.
// Each one derives from the matching vector type (short_type##1 ... short_type##4),
// exposes that type as `Type` and `base_type` as `BaseType`, and adds
// component-wise operator+ and operator- that return a new CubVector.
#  define CUB_DEFINE_VECTOR_TYPE(base_type, short_type)                                     \
                                                                                            \
    template <>                                                                             \
    struct CubVector<base_type, 1> : short_type##1                                          \
    {                                                                                       \
      using BaseType = base_type;                                                           \
      using Type     = short_type##1;                                                       \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator+(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x + other.x;                                                             \
        return retval;                                                                      \
      }                                                                                     \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator-(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x - other.x;                                                             \
        return retval;                                                                      \
      }                                                                                     \
    };                                                                                      \
                                                                                            \
    template <>                                                                             \
    struct CubVector<base_type, 2> : short_type##2                                          \
    {                                                                                       \
      using BaseType = base_type;                                                           \
      using Type     = short_type##2;                                                       \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator+(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x + other.x;                                                             \
        retval.y = y + other.y;                                                             \
        return retval;                                                                      \
      }                                                                                     \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator-(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x - other.x;                                                             \
        retval.y = y - other.y;                                                             \
        return retval;                                                                      \
      }                                                                                     \
    };                                                                                      \
                                                                                            \
    template <>                                                                             \
    struct CubVector<base_type, 3> : short_type##3                                          \
    {                                                                                       \
      using BaseType = base_type;                                                           \
      using Type     = short_type##3;                                                       \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator+(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x + other.x;                                                             \
        retval.y = y + other.y;                                                             \
        retval.z = z + other.z;                                                             \
        return retval;                                                                      \
      }                                                                                     \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator-(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x - other.x;                                                             \
        retval.y = y - other.y;                                                             \
        retval.z = z - other.z;                                                             \
        return retval;                                                                      \
      }                                                                                     \
    };                                                                                      \
                                                                                            \
    template <>                                                                             \
    struct CubVector<base_type, 4> : short_type##4                                          \
    {                                                                                       \
      using BaseType = base_type;                                                           \
      using Type     = short_type##4;                                                       \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator+(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x + other.x;                                                             \
        retval.y = y + other.y;                                                             \
        retval.z = z + other.z;                                                             \
        retval.w = w + other.w;                                                             \
        return retval;                                                                      \
      }                                                                                     \
      _CCCL_HOST_DEVICE _CCCL_FORCEINLINE CubVector operator-(const CubVector& other) const \
      {                                                                                     \
        CubVector retval;                                                                   \
        retval.x = x - other.x;                                                             \
        retval.y = y - other.y;                                                             \
        retval.z = z - other.z;                                                             \
        retval.w = w - other.w;                                                             \
        return retval;                                                                      \
      }                                                                                     \
    };

// Expand CUDA vector types for built-in primitives.
// The second argument is the name stem of the vector type family used as the
// base class; note that `char` and `signed char` both map to the char family
// and `bool` maps to the uchar family.
// clang-format off
CUB_DEFINE_VECTOR_TYPE(char,               char)
CUB_DEFINE_VECTOR_TYPE(signed char,        char)
CUB_DEFINE_VECTOR_TYPE(short,              short)
CUB_DEFINE_VECTOR_TYPE(int,                int)
CUB_DEFINE_VECTOR_TYPE(long,               long)
CUB_DEFINE_VECTOR_TYPE(long long,          longlong)
CUB_DEFINE_VECTOR_TYPE(unsigned char,      uchar)
CUB_DEFINE_VECTOR_TYPE(unsigned short,     ushort)
CUB_DEFINE_VECTOR_TYPE(unsigned int,       uint)
CUB_DEFINE_VECTOR_TYPE(unsigned long,      ulong)
CUB_DEFINE_VECTOR_TYPE(unsigned long long, ulonglong)
CUB_DEFINE_VECTOR_TYPE(float,              float)
CUB_DEFINE_VECTOR_TYPE(double,             double)
CUB_DEFINE_VECTOR_TYPE(bool,               uchar)
// clang-format on

// The helper macro is only needed for the expansions above
#  undef CUB_DEFINE_VECTOR_TYPE

/******************************************************************************
 * Wrapper types
 ******************************************************************************/

/**
 * \brief Raw storage with the size of a \p T, declared as an array of
 *        UnitWord<T>::DeviceWord so that no constructor of \p T is run.
 *
 * Alias() reinterprets the storage as a \p T.
 */
template <typename T>
struct Uninitialized
{
  /// Word type the backing array is made of
  using DeviceWord = typename UnitWord<T>::DeviceWord;

  static constexpr ::cuda::std::size_t DATA_SIZE = sizeof(T);
  static constexpr ::cuda::std::size_t WORD_SIZE = sizeof(DeviceWord);
  static constexpr ::cuda::std::size_t WORDS     = DATA_SIZE / WORD_SIZE;

  /// Backing storage
  DeviceWord storage[WORDS];

  /// Views this object as a T
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T& Alias()
  {
    T* const aliased = reinterpret_cast<T*>(this);
    return *aliased;
  }
};

/**
 * \brief A key together with its associated value.
 *
 * On 32-bit Windows two extra defaulted template parameters record which of
 * the two types has the smaller alignment; they select the specializations
 * that follow this template.
 *
 * \tparam _Key   Key type
 * \tparam _Value Value type
 */
template <typename _Key,
          typename _Value
#  if defined(_WIN32) && !defined(_WIN64)
          ,
          bool KeyIsLT = (AlignBytes<_Key>::ALIGN_BYTES < AlignBytes<_Value>::ALIGN_BYTES),
          bool ValIsLT = (AlignBytes<_Value>::ALIGN_BYTES < AlignBytes<_Key>::ALIGN_BYTES)
#  endif // #if defined(_WIN32) && !defined(_WIN64)
          >
struct KeyValuePair
{
  using Key   = _Key;
  using Value = _Value;

  Key key;
  Value value;

  /// Leaves key and value default-initialized
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair() {}

  /// Copies \p key and \p value into the pair
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair(Key const& key, Value const& value)
      : key(key)
      , value(value)
  {}

  /// True when the values or the keys differ. Const-qualified so that it can
  /// also be used on const pairs.
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator!=(const KeyValuePair& b) const
  {
    return (value != b.value) || (key != b.key);
  }
};

#  if defined(_WIN32) && !defined(_WIN64)

/// 32-bit Windows only: pair whose key has the smaller alignment. The value is
/// stored first and explicit padding of the alignment difference follows the key.
template <typename K, typename V>
struct KeyValuePair<K, V, true, false>
{
  using Key   = K;
  using Value = V;

  using Pad = char[AlignBytes<V>::ALIGN_BYTES - AlignBytes<K>::ALIGN_BYTES];

  Value value; // Value has larger would-be alignment and goes first
  Key key;
  Pad pad;

  /// Leaves all members default-initialized
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair() {}

  /// Copies \p key and \p value into the pair. The initializers are listed in
  /// declaration order (value before key), which is the order they run in.
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair(Key const& key, Value const& value)
      : value(value)
      , key(key)
  {}

  /// True when the values or the keys differ. Const-qualified so that it can
  /// also be used on const pairs.
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator!=(const KeyValuePair& b) const
  {
    return (value != b.value) || (key != b.key);
  }
};

/// 32-bit Windows only: pair whose value has the smaller alignment. The key is
/// stored first and explicit padding of the alignment difference follows the value.
template <typename K, typename V>
struct KeyValuePair<K, V, false, true>
{
  using Key   = K;
  using Value = V;

  using Pad = char[AlignBytes<K>::ALIGN_BYTES - AlignBytes<V>::ALIGN_BYTES];

  Key key; // Key has larger would-be alignment and goes first
  Value value;
  Pad pad;

  /// Leaves all members default-initialized
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair() {}

  /// Copies \p key and \p value into the pair
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePair(Key const& key, Value const& value)
      : key(key)
      , value(value)
  {}

  /// True when the values or the keys differ. Const-qualified so that it can
  /// also be used on const pairs.
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator!=(const KeyValuePair& b) const
  {
    return (value != b.value) || (key != b.key);
  }
};

#  endif // #if defined(_WIN32) && !defined(_WIN64)

/**
 * \brief Struct wrapper around a fixed-size array of \p COUNT elements of type \p T.
 *
 * Deprecated: use cuda::std::array instead.
 */
template <typename T, int COUNT>
struct CUB_DEPRECATED_BECAUSE("Use cuda::std::array instead.") ArrayWrapper
{
  T array[COUNT];

  // User-provided empty constructor: the array elements are default-initialized
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE ArrayWrapper() {}
};

/**
 * \brief Pairs two buffer pointers with a selector that designates one of them
 *        as the "current" buffer and the other as the "alternate" buffer.
 *
 * \tparam T Element type of the buffers
 */
template <typename T>
struct DoubleBuffer
{
  /// The two buffer pointers
  T* d_buffers[2];

  /// Index into d_buffers of the current buffer; the accessors assume it is 0 or 1
  int selector;

  /// Creates a double buffer with both pointers null and buffer 0 selected
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE DoubleBuffer()
      : d_buffers{nullptr, nullptr}
      , selector(0)
  {}

  /// Creates a double buffer with \p d_current selected and \p d_alternate as the other buffer
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE DoubleBuffer(T* d_current,
                                                   T* d_alternate)
      : d_buffers{d_current, d_alternate}
      , selector(0)
  {}

  /// Returns the buffer designated by the selector. Const-qualified: the
  /// accessor does not modify the double buffer.
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T* Current() const
  {
    return d_buffers[selector];
  }

  /// Returns the buffer not designated by the selector. Const-qualified: the
  /// accessor does not modify the double buffer.
  _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T* Alternate() const
  {
    // Flipping the low bit maps 0 -> 1 and 1 -> 0
    return d_buffers[selector ^ 1];
  }
};

/******************************************************************************
 * Typedef-detection
 ******************************************************************************/

// Defines a class template `detector_name<T>` whose VALUE is true when
// `T::nested_type_name` names a type. The call `test<T>(0)` resolves to the
// overload taking a pointer to the nested type (returning char&) when that type
// exists and to the variadic overload (returning int&) otherwise; VALUE
// compares the size of the selected return type against sizeof(int).
#  define CUB_DEFINE_DETECT_NESTED_TYPE(detector_name, nested_type_name) \
    template <typename T>                                                \
    struct detector_name                                                 \
    {                                                                    \
      template <typename C>                                              \
      static char& test(typename C::nested_type_name*);                  \
      template <typename>                                                \
      static int& test(...);                                             \
      enum                                                               \
      {                                                                  \
        VALUE = sizeof(test<T>(0)) < sizeof(int)                         \
      };                                                                 \
    };

/******************************************************************************
 * Detection of binary operators that take an index parameter
 ******************************************************************************/

/**
 * \brief Detects whether \p BinaryOp has an `operator()` that returns bool and
 *        takes two \p T operands followed by an `int` index.
 *
 * HAS_PARAM is true when `&BinaryOp::operator()` matches one of the four
 * enabled member-function signatures: operands passed as `const T&` or by
 * value, and the operator const-qualified or not. The variants taking an
 * `unsigned int` index are commented out and therefore not detected.
 */
template <typename T, typename BinaryOp>
struct BinaryOpHasIdxParam
{
private:
  /*
      template <typename BinaryOpT, bool (BinaryOpT::*)(const T &a, const T &b, unsigned int idx) const>  struct SFINAE1
     {}; template <typename BinaryOpT, bool (BinaryOpT::*)(const T &a, const T &b, unsigned int idx)>        struct
     SFINAE2 {}; template <typename BinaryOpT, bool (BinaryOpT::*)(T a, T b, unsigned int idx) const> struct SFINAE3 {};
      template <typename BinaryOpT, bool (BinaryOpT::*)(T a, T b, unsigned int idx)>                      struct SFINAE4
     {};
  */
  // Each helper can only be named with a member-function pointer of exactly the given signature
  template <typename BinaryOpT, bool (BinaryOpT::*)(const T& a, const T& b, int idx) const>
  struct SFINAE5
  {};
  template <typename BinaryOpT, bool (BinaryOpT::*)(const T& a, const T& b, int idx)>
  struct SFINAE6
  {};
  template <typename BinaryOpT, bool (BinaryOpT::*)(T a, T b, int idx) const>
  struct SFINAE7
  {};
  template <typename BinaryOpT, bool (BinaryOpT::*)(T a, T b, int idx)>
  struct SFINAE8
  {};
  /*
      template <typename BinaryOpT> static char Test(SFINAE1<BinaryOpT, &BinaryOpT::operator()> *);
      template <typename BinaryOpT> static char Test(SFINAE2<BinaryOpT, &BinaryOpT::operator()> *);
      template <typename BinaryOpT> static char Test(SFINAE3<BinaryOpT, &BinaryOpT::operator()> *);
      template <typename BinaryOpT> static char Test(SFINAE4<BinaryOpT, &BinaryOpT::operator()> *);
  */
  // Overloads returning char are viable only when the corresponding signature matches
  template <typename BinaryOpT>
  _CCCL_HOST_DEVICE static char Test(SFINAE5<BinaryOpT, &BinaryOpT::operator()>*);
  template <typename BinaryOpT>
  _CCCL_HOST_DEVICE static char Test(SFINAE6<BinaryOpT, &BinaryOpT::operator()>*);
  template <typename BinaryOpT>
  _CCCL_HOST_DEVICE static char Test(SFINAE7<BinaryOpT, &BinaryOpT::operator()>*);
  template <typename BinaryOpT>
  _CCCL_HOST_DEVICE static char Test(SFINAE8<BinaryOpT, &BinaryOpT::operator()>*);

  // Fallback returning int, chosen when none of the signatures above match
  template <typename BinaryOpT>
  _CCCL_HOST_DEVICE static int Test(...);

public:
  // The return type of the selected overload (char vs. int) encodes the answer
  static constexpr bool HAS_PARAM = sizeof(Test<BinaryOp>(nullptr)) == sizeof(char);
};

/******************************************************************************
 * Simple type traits utilities.
 *
 * For example:
 *     Traits<int>::CATEGORY             // SIGNED_INTEGER
 *     Traits<NullType>::NULL_TYPE       // true
 *     Traits<uint4>::CATEGORY           // NOT_A_NUMBER
 *     Traits<uint4>::PRIMITIVE;         // false
 *
 ******************************************************************************/

/// Basic type categories reported through the CATEGORY member of the traits classes
enum Category
{
  NOT_A_NUMBER,
  SIGNED_INTEGER,
  UNSIGNED_INTEGER,
  FLOATING_POINT
};

/**
 * \brief Primary traits template: republishes its template arguments as the
 *        members CATEGORY, PRIMITIVE and NULL_TYPE.
 *
 * This general form defines no UnsignedBits type, key constants or twiddling
 * functions; those are provided by the specializations.
 */
template <Category _CATEGORY, bool _PRIMITIVE, bool _NULL_TYPE, typename _UnsignedBits, typename T>
struct BaseTraits
{
  static constexpr Category CATEGORY = _CATEGORY;
  enum
  {
    PRIMITIVE = _PRIMITIVE,
    NULL_TYPE = _NULL_TYPE,
  };
};

/**
 * \brief Traits for primitive unsigned integer types.
 *
 * The key representation is the value itself, so TwiddleIn and TwiddleOut
 * return their argument unchanged. LOWEST_KEY has all bits clear and MAX_KEY
 * has all bits set.
 */
template <typename _UnsignedBits, typename T>
struct BaseTraits<UNSIGNED_INTEGER, true, false, _UnsignedBits, T>
{
  using UnsignedBits = _UnsignedBits;

  enum
  {
    PRIMITIVE = true,
    NULL_TYPE = false,
  };

  static constexpr Category CATEGORY       = UNSIGNED_INTEGER;
  static constexpr UnsignedBits LOWEST_KEY = UnsignedBits(0);
  static constexpr UnsignedBits MAX_KEY    = UnsignedBits(-1);

  /// Identity: unsigned keys need no transformation
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleIn(UnsignedBits key)
  {
    return key;
  }

  /// Identity: inverse of TwiddleIn
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleOut(UnsignedBits key)
  {
    return key;
  }

  /// Returns the T whose object representation is LOWEST_KEY
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Lowest()
  {
    // Copy the constant into a local first so that its address can be taken
    const UnsignedBits lowest_bits = LOWEST_KEY;
    T lowest_value;
    memcpy(&lowest_value, &lowest_bits, sizeof(T));
    return lowest_value;
  }

  /// Returns the T whose object representation is MAX_KEY
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Max()
  {
    // Copy the constant into a local first so that its address can be taken
    const UnsignedBits max_bits = MAX_KEY;
    T max_value;
    memcpy(&max_value, &max_bits, sizeof(T));
    return max_value;
  }
};

/**
 * \brief Traits for primitive signed integer types.
 *
 * Keys are represented as UnsignedBits with the most significant bit flipped
 * (TwiddleIn); TwiddleOut applies the same flip again and thereby restores the
 * original bits. LOWEST_KEY is the pattern with only the high bit set and
 * MAX_KEY the pattern with every bit but the high bit set.
 */
template <typename _UnsignedBits, typename T>
struct BaseTraits<SIGNED_INTEGER, true, false, _UnsignedBits, T>
{
  using UnsignedBits = _UnsignedBits;

  static constexpr Category CATEGORY       = SIGNED_INTEGER;
  static constexpr UnsignedBits HIGH_BIT   = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1);
  static constexpr UnsignedBits LOWEST_KEY = HIGH_BIT;
  static constexpr UnsignedBits MAX_KEY    = UnsignedBits(-1) ^ HIGH_BIT;

  enum
  {
    PRIMITIVE = true,
    NULL_TYPE = false,
  };

  /// Flips the high bit of \p key
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleIn(UnsignedBits key)
  {
    return key ^ HIGH_BIT;
  }

  /// Flips the high bit of \p key (inverse of TwiddleIn)
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleOut(UnsignedBits key)
  {
    return key ^ HIGH_BIT;
  }

  /// Returns the T whose object representation is MAX_KEY
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Max()
  {
    // The bits are copied with memcpy, as in the unsigned specialization,
    // instead of reading an UnsignedBits object through a T reference.
    UnsignedBits retval_bits = MAX_KEY;
    T retval;
    memcpy(&retval, &retval_bits, sizeof(T));
    return retval;
  }

  /// Returns the T whose object representation is LOWEST_KEY
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Lowest()
  {
    UnsignedBits retval_bits = LOWEST_KEY;
    T retval;
    memcpy(&retval, &retval_bits, sizeof(T));
    return retval;
  }
};

/// Limits of floating-point types: each specialization provides static Max()
/// and Lowest() functions. The primary template is declared but not defined.
template <typename _T>
struct FpLimits;

/// Limits of float, forwarded from ::cuda::std::numeric_limits<float>
template <>
struct FpLimits<float>
{
  /// numeric_limits<float>::max()
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE float Max()
  {
    using limits_t = ::cuda::std::numeric_limits<float>;
    return limits_t::max();
  }

  /// numeric_limits<float>::lowest()
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE float Lowest()
  {
    using limits_t = ::cuda::std::numeric_limits<float>;
    return limits_t::lowest();
  }
};

/**
 * Limits for double, forwarded from ::cuda::std::numeric_limits<double>.
 */
template <>
struct FpLimits<double>
{
  // Returns numeric_limits<double>::max().
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE double Max()
  {
    using limits = ::cuda::std::numeric_limits<double>;
    return limits::max();
  }

  // Returns numeric_limits<double>::lowest().
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE double Lowest()
  {
    using limits = ::cuda::std::numeric_limits<double>;
    return limits::lowest();
  }
};

#  if defined(_CCCL_HAS_NVFP16)
/**
 * Limits for __half, built from fixed 16-bit patterns.
 *
 * The bytes are copied with memcpy rather than by casting an unsigned short
 * lvalue to __half&, which would read the object through an unrelated type.
 * The FP8 specializations in this file use the same memcpy approach.
 */
template <>
struct FpLimits<__half>
{
  // Returns the __half with bit pattern 0x7BFF
  // (sign 0, exponent 0b11110, mantissa all ones in a 1/5/10 layout).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __half Max()
  {
    unsigned short max_word = 0x7BFF;
    __half ret_val;
    memcpy(&ret_val, &max_word, sizeof(__half));
    return ret_val;
  }

  // Returns the __half with bit pattern 0xFBFF (0x7BFF with the sign bit set).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __half Lowest()
  {
    unsigned short lowest_word = 0xFBFF;
    __half ret_val;
    memcpy(&ret_val, &lowest_word, sizeof(__half));
    return ret_val;
  }
};
#  endif // _CCCL_HAS_NVFP16

#  if defined(_CCCL_HAS_NVBF16)
/**
 * Limits for __nv_bfloat16, built from fixed 16-bit patterns.
 *
 * The bytes are copied with memcpy rather than by casting an unsigned short
 * lvalue to __nv_bfloat16&, which would read the object through an unrelated
 * type. The FP8 specializations in this file use the same memcpy approach.
 */
template <>
struct FpLimits<__nv_bfloat16>
{
  // Returns the __nv_bfloat16 with bit pattern 0x7F7F
  // (sign 0, exponent 0b11111110, mantissa all ones in a 1/8/7 layout).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __nv_bfloat16 Max()
  {
    unsigned short max_word = 0x7F7F;
    __nv_bfloat16 ret_val;
    memcpy(&ret_val, &max_word, sizeof(__nv_bfloat16));
    return ret_val;
  }

  // Returns the __nv_bfloat16 with bit pattern 0xFF7F (0x7F7F with the sign bit set).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __nv_bfloat16 Lowest()
  {
    unsigned short lowest_word = 0xFF7F;
    __nv_bfloat16 ret_val;
    memcpy(&ret_val, &lowest_word, sizeof(__nv_bfloat16));
    return ret_val;
  }
};
#  endif // _CCCL_HAS_NVBF16

#  if defined(__CUDA_FP8_TYPES_EXIST__)
/**
 * Limits for __nv_fp8_e4m3, produced by copying a fixed 8-bit pattern into
 * a default-constructed value.
 */
template <>
struct FpLimits<__nv_fp8_e4m3>
{
  // Returns the value with bit pattern 0x7E
  // (presumably the largest finite E4M3 encoding -- confirm against the FP8 spec).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __nv_fp8_e4m3 Max()
  {
    unsigned char bits = 0x7EU;
    __nv_fp8_e4m3 result;
    memcpy(&result, &bits, sizeof(result));
    return result;
  }

  // Returns the value with bit pattern 0xFE (0x7E with the top bit set).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __nv_fp8_e4m3 Lowest()
  {
    unsigned char bits = 0xFEU;
    __nv_fp8_e4m3 result;
    memcpy(&result, &bits, sizeof(result));
    return result;
  }
};

/**
 * Limits for __nv_fp8_e5m2, produced by copying a fixed 8-bit pattern into
 * a default-constructed value.
 */
template <>
struct FpLimits<__nv_fp8_e5m2>
{
  // Returns the value with bit pattern 0x7B
  // (presumably the largest finite E5M2 encoding -- confirm against the FP8 spec).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __nv_fp8_e5m2 Max()
  {
    unsigned char bits = 0x7BU;
    __nv_fp8_e5m2 result;
    memcpy(&result, &bits, sizeof(result));
    return result;
  }

  // Returns the value with bit pattern 0xFB (0x7B with the top bit set).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE __nv_fp8_e5m2 Lowest()
  {
    unsigned char bits = 0xFBU;
    __nv_fp8_e5m2 result;
    memcpy(&result, &bits, sizeof(result));
    return result;
  }
};

#  endif // __CUDA_FP8_TYPES_EXIST__

/**
 * BaseTraits specialization for Category FLOATING_POINT with the two boolean
 * template arguments fixed to <true, false> (they match the PRIMITIVE and
 * NULL_TYPE enumerators below; the primary template is declared elsewhere).
 *
 * _UnsignedBits is the unsigned type that carries the raw bits of a T key.
 * Max()/Lowest() are delegated to FpLimits<T>.
 */
template <typename _UnsignedBits, typename T>
struct BaseTraits<FLOATING_POINT, true, false, _UnsignedBits, T>
{
  using UnsignedBits = _UnsignedBits;

  static constexpr Category CATEGORY = FLOATING_POINT;
  // Only the most significant bit of UnsignedBits set.
  static constexpr UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1);
  // All bits set.
  static constexpr UnsignedBits LOWEST_KEY = UnsignedBits(-1);
  // All bits set except the most significant one.
  static constexpr UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT;

  enum
  {
    PRIMITIVE = true,
    NULL_TYPE = false,
  };

  // Top bit set: invert every bit. Top bit clear: set only the top bit.
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleIn(UnsignedBits key)
  {
    const bool top_bit_set = (key & HIGH_BIT) != 0;
    const UnsignedBits flip = top_bit_set ? UnsignedBits(-1) : HIGH_BIT;
    return key ^ flip;
  }

  // Inverse of TwiddleIn. Top bit set: clear only the top bit. Top bit clear: invert every bit.
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleOut(UnsignedBits key)
  {
    const bool top_bit_set = (key & HIGH_BIT) != 0;
    const UnsignedBits flip = top_bit_set ? HIGH_BIT : UnsignedBits(-1);
    return key ^ flip;
  }

  // Largest value of T, as reported by FpLimits<T>.
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Max()
  {
    return FpLimits<T>::Max();
  }

  // Lowest value of T, as reported by FpLimits<T>.
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Lowest()
  {
    return FpLimits<T>::Lowest();
  }
};

// clang-format off
// NumericTraits maps a type onto a BaseTraits instantiation. The BaseTraits
// arguments are: category, two boolean flags (presumably primitive / null-type --
// confirm against the primary BaseTraits template), the unsigned type that holds
// the key bits, and the type itself.

// Fallback for any type without an explicit specialization: NOT_A_NUMBER, both flags false.
template <typename T> struct NumericTraits :            BaseTraits<NOT_A_NUMBER, false, false, T, T> {};

// NullType: NOT_A_NUMBER with the second flag set to true.
template <> struct NumericTraits<NullType> :            BaseTraits<NOT_A_NUMBER, false, true, NullType, NullType> {};

// Signed integers. Plain char picks its category from numeric_limits<char>::is_signed,
// since its signedness is implementation-defined.
template <> struct NumericTraits<char> :                BaseTraits<(::cuda::std::numeric_limits<char>::is_signed) ? SIGNED_INTEGER : UNSIGNED_INTEGER, true, false, unsigned char, char> {};
template <> struct NumericTraits<signed char> :         BaseTraits<SIGNED_INTEGER, true, false, unsigned char, signed char> {};
template <> struct NumericTraits<short> :               BaseTraits<SIGNED_INTEGER, true, false, unsigned short, short> {};
template <> struct NumericTraits<int> :                 BaseTraits<SIGNED_INTEGER, true, false, unsigned int, int> {};
template <> struct NumericTraits<long> :                BaseTraits<SIGNED_INTEGER, true, false, unsigned long, long> {};
template <> struct NumericTraits<long long> :           BaseTraits<SIGNED_INTEGER, true, false, unsigned long long, long long> {};

// Unsigned integers: the type is its own bit container.
template <> struct NumericTraits<unsigned char> :       BaseTraits<UNSIGNED_INTEGER, true, false, unsigned char, unsigned char> {};
template <> struct NumericTraits<unsigned short> :      BaseTraits<UNSIGNED_INTEGER, true, false, unsigned short, unsigned short> {};
template <> struct NumericTraits<unsigned int> :        BaseTraits<UNSIGNED_INTEGER, true, false, unsigned int, unsigned int> {};
template <> struct NumericTraits<unsigned long> :       BaseTraits<UNSIGNED_INTEGER, true, false, unsigned long, unsigned long> {};
template <> struct NumericTraits<unsigned long long> :  BaseTraits<UNSIGNED_INTEGER, true, false, unsigned long long, unsigned long long> {};

#if CUB_IS_INT128_ENABLED
/**
 * NumericTraits for __uint128_t, written out in full instead of deriving
 * from BaseTraits. The type serves as its own bit container, and PRIMITIVE
 * is false here.
 */
template <>
struct NumericTraits<__uint128_t>
{
  using T            = __uint128_t;
  using UnsignedBits = __uint128_t;

  static constexpr bool PRIMITIVE = false;
  static constexpr bool NULL_TYPE = false;

  static constexpr Category CATEGORY = UNSIGNED_INTEGER;
  // All bits clear.
  static constexpr UnsignedBits LOWEST_KEY = UnsignedBits(0);
  // All bits set.
  static constexpr UnsignedBits MAX_KEY = UnsignedBits(-1);

  // Identity transform: returns the key bits unchanged.
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleIn(UnsignedBits key)
  {
    return key;
  }

  // Identity transform: returns the key bits unchanged.
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleOut(UnsignedBits key)
  {
    return key;
  }

  // Largest value: MAX_KEY (T and UnsignedBits are the same type).
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Max()
  {
    return MAX_KEY;
  }

  // Lowest value: LOWEST_KEY, i.e. zero.
  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Lowest()
  {
    return LOWEST_KEY;
  }
};

template <>
struct NumericTraits<__int128_t>
{
  using T = __int128_t;
  using UnsignedBits = __uint128_t;

  static constexpr Category       CATEGORY    = SIGNED_INTEGER;
  static constexpr UnsignedBits   HIGH_BIT    = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1);
  static constexpr UnsignedBits   LOWEST_KEY  = HIGH_BIT;
  static constexpr UnsignedBits   MAX_KEY     = UnsignedBits(-1) ^ HIGH_BIT;

  static constexpr bool PRIMITIVE = false;
  static constexpr bool NULL_TYPE = false;

  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleIn(UnsignedBits key)
  {
    return key ^ HIGH_BIT;
  };

  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE UnsignedBits TwiddleOut(UnsignedBits key)
  {
    return key ^ HIGH_BIT;
  };

  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Max()
  {
    UnsignedBits retval = MAX_KEY;
    return reinterpret_cast<T&>(retval);
  }

  static _CCCL_HOST_DEVICE _CCCL_FORCEINLINE T Lowest()
  {
    UnsignedBits retval = LOWEST_KEY;
    return reinterpret_cast<T&>(retval);
  }
};
#endif

// Floating-point types: each is paired with an unsigned bit container
// (unsigned int for float, unsigned long long for double).
template <> struct NumericTraits<float> :               BaseTraits<FLOATING_POINT, true, false, unsigned int, float> {};
template <> struct NumericTraits<double> :              BaseTraits<FLOATING_POINT, true, false, unsigned long long, double> {};
// Reduced-precision types are only registered when the matching feature macro is defined.
#  if defined(_CCCL_HAS_NVFP16)
    template <> struct NumericTraits<__half> :          BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
#  endif // _CCCL_HAS_NVFP16
#  if defined(_CCCL_HAS_NVBF16)
    template <> struct NumericTraits<__nv_bfloat16> :   BaseTraits<FLOATING_POINT, true, false, unsigned short, __nv_bfloat16> {};
#  endif // _CCCL_HAS_NVBF16

// FP8 types use __nv_fp8_storage_t as their bit container.
#if defined(__CUDA_FP8_TYPES_EXIST__)
    template <> struct NumericTraits<__nv_fp8_e4m3> :   BaseTraits<FLOATING_POINT, true, false, __nv_fp8_storage_t, __nv_fp8_e4m3> {};
    template <> struct NumericTraits<__nv_fp8_e5m2> :   BaseTraits<FLOATING_POINT, true, false, __nv_fp8_storage_t, __nv_fp8_e5m2> {};
#endif // __CUDA_FP8_TYPES_EXIST__

// bool is categorized as UNSIGNED_INTEGER; its bit container is UnitWord<bool>::VolatileWord.
template <> struct NumericTraits<bool> :                BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
// clang-format on

/**
 * Type traits entry point: strips const/volatile qualifiers from T and
 * inherits from NumericTraits of the unqualified type, so cv-qualified types
 * get the same traits as their unqualified counterparts.
 */
template <typename T>
struct Traits : NumericTraits<typename ::cuda::std::remove_cv<T>::type>
{};

#endif // DOXYGEN_SHOULD_SKIP_THIS

CUB_NAMESPACE_END