cub::WarpMergeSort

Defined in cub/warp/warp_merge_sort.cuh

template<typename KeyT, int ITEMS_PER_THREAD, int LOGICAL_WARP_THREADS = detail::warp_threads, typename ValueT = NullType>
class WarpMergeSort : public cub::BlockMergeSortStrategy<KeyT, ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, WarpMergeSort<KeyT, ITEMS_PER_THREAD, LOGICAL_WARP_THREADS, ValueT>>

The WarpMergeSort class provides methods for sorting items partitioned across a CUDA warp (or a smaller logical warp) using merge sort.

Overview

WarpMergeSort arranges items into ascending order using a comparison functor with less-than semantics. Because merge sort relies only on pairwise comparisons, it can handle arbitrary key types and comparison functors.
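"Less-than semantics" means the functor returns true exactly when lhs must be placed before rhs, and false in both directions for equivalent elements. The host-side C++ sketch below illustrates this with std::sort standing in for the device sort; the Particle type and the ByCellLess and DescendingLess functors are hypothetical. In device code the call operators would additionally need a `__device__` (or `__host__ __device__`) qualifier, as in CustomLess in the example that follows.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical user-defined key type: only the comparison functor
// needs to know how two instances are ordered.
struct Particle
{
    int   cell;   // field used for ordering
    float energy; // payload that plays no part in the comparison
};

// Less-than semantics: return true iff lhs must be placed before rhs.
// Elements with equal cells compare false in both directions.
struct ByCellLess
{
    bool operator()(const Particle &lhs, const Particle &rhs) const
    {
        return lhs.cell < rhs.cell;
    }
};

// A descending order is still expressed with less-than semantics:
// "lhs precedes rhs" simply means "lhs is larger".
struct DescendingLess
{
    bool operator()(int lhs, int rhs) const { return lhs > rhs; }
};

// Host stand-ins for the device sort, used only to show the orderings.
std::vector<Particle> sort_by_cell(std::vector<Particle> v)
{
    std::sort(v.begin(), v.end(), ByCellLess());
    return v;
}

std::vector<int> sort_descending(std::vector<int> v)
{
    std::sort(v.begin(), v.end(), DescendingLess());
    return v;
}
```

The same functor object is what would be passed as the last argument of Sort on the device; nothing else about the key type has to be known to the sorting algorithm.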

A Simple Example

The code snippet below illustrates a sort of 64 integer keys that are partitioned across a logical warp of 16 threads, where each thread owns 4 consecutive items.

#include <cub/cub.cuh>  // or equivalently <cub/warp/warp_merge_sort.cuh>

struct CustomLess
{
  template <typename DataType>
  __device__ bool operator()(const DataType &lhs, const DataType &rhs) const
  {
    return lhs < rhs;
  }
};

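__global__ void ExampleKernel(int *d_keys)
{
    constexpr int warp_threads = 16;
    constexpr int block_threads = 256;
    constexpr int items_per_thread = 4;
    constexpr int warps_per_block = block_threads / warp_threads;
    constexpr int tile_size = warp_threads * items_per_thread;
    const int warp_id = static_cast<int>(threadIdx.x) / warp_threads;
    const int lane_id = static_cast<int>(threadIdx.x) % warp_threads;

    // Specialize WarpMergeSort for a virtual warp of 16 threads
    // owning 4 integer items each
    using WarpMergeSortT =
      cub::WarpMergeSort<int, items_per_thread, warp_threads>;

    // Allocate shared memory for WarpMergeSort (one TempStorage per logical warp)
    __shared__ typename WarpMergeSortT::TempStorage temp_storage[warps_per_block];

    // Obtain a segment of consecutive items that are blocked across threads:
    // each logical warp handles one tile of 64 keys, and each lane loads
    // items_per_thread consecutive keys from that tile
    int *warp_keys = d_keys
                   + (blockIdx.x * warps_per_block + warp_id) * tile_size;
    int thread_keys[items_per_thread];
    for (int i = 0; i < items_per_thread; ++i)
    {
        thread_keys[i] = warp_keys[lane_id * items_per_thread + i];
    }

    WarpMergeSortT(temp_storage[warp_id]).Sort(thread_keys, CustomLess());

    // Write the sorted keys back in the same blocked arrangement
    for (int i = 0; i < items_per_thread; ++i)
    {
        warp_keys[lane_id * items_per_thread + i] = thread_keys[i];
    }
}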

Suppose the set of input thread_keys across the 16 threads of a logical warp is { [0,64,1,63], [2,62,3,61], [4,60,5,59], ..., [30,34,31,33] }. The corresponding output thread_keys in those threads will be { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [28,29,30,31], [33,34,35,36], ..., [61,62,63,64] } (the key 32 does not occur in this input).
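The expected output can be checked on the host. The C++ sketch below assumes that the pattern of the first three threads continues, so that thread t holds [2t, 64-2t, 2t+1, 63-2t]; it builds that input, flattens it in thread order, sorts it, and hands each thread the next four consecutive keys, which is what a warp-wide sort in a blocked arrangement produces. The helpers make_example_input and reference_warp_sort are hypothetical and are not part of CUB.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kWarpThreads    = 16; // logical warp size in the example
constexpr int kItemsPerThread = 4;  // items owned by each thread

using Blocked = std::vector<std::array<int, kItemsPerThread>>;

// Assumed input pattern: thread t holds [2t, 64-2t, 2t+1, 63-2t].
Blocked make_example_input()
{
    Blocked keys(kWarpThreads);
    for (int t = 0; t < kWarpThreads; ++t)
        keys[t] = {2 * t, 64 - 2 * t, 2 * t + 1, 63 - 2 * t};
    return keys;
}

// Reference result of a warp-wide sort in a blocked arrangement:
// flatten in thread order, sort ascending, then give each thread
// the next kItemsPerThread consecutive keys.
Blocked reference_warp_sort(const Blocked &in)
{
    std::vector<int> flat;
    for (const auto &thread_keys : in)
        flat.insert(flat.end(), thread_keys.begin(), thread_keys.end());
    std::sort(flat.begin(), flat.end());

    Blocked out(in.size());
    for (std::size_t t = 0; t < in.size(); ++t)
        for (int i = 0; i < kItemsPerThread; ++i)
            out[t][i] = flat[t * kItemsPerThread + i];
    return out;
}
```

Because the input contains 0 through 31 and 33 through 64, the first eight threads end up with 0 through 31 and the last eight with 33 through 64.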

Template Parameters
  • KeyT – Key type

  • ITEMS_PER_THREAD – The number of items per thread

  • LOGICAL_WARP_THREADS – [optional] The number of threads per “logical” warp (may be fewer than the number of hardware warp threads). Defaults to the warp size of the targeted CUDA compute capability (e.g., 32 threads for SM86). Must be a power of two.

  • ValueT – [optional] Value type (default: cub::NullType, which indicates a keys-only sort)
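In a one-dimensional thread block, these parameters determine how the block is carved up: each group of LOGICAL_WARP_THREADS consecutive threads forms one logical warp that sorts ITEMS_PER_THREAD * LOGICAL_WARP_THREADS keys, which is why the example allocates one TempStorage per logical warp and indexes it with threadIdx.x / warp_threads. The host-side C++ sketch below spells out that index arithmetic under this assumption; LogicalWarpLayout and its members are hypothetical helpers, not part of CUB.

```cpp
#include <cassert>

// Index arithmetic for logical warps in a 1D thread block, mirroring the
// warp_id computation in the example kernel. Names are illustrative only.
template <int LOGICAL_WARP_THREADS, int ITEMS_PER_THREAD>
struct LogicalWarpLayout
{
    static_assert(LOGICAL_WARP_THREADS > 0 &&
                  (LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0,
                  "LOGICAL_WARP_THREADS must be a power of two");

    // Number of keys sorted together by one logical warp.
    static constexpr int tile_size = LOGICAL_WARP_THREADS * ITEMS_PER_THREAD;

    // Which logical warp a thread of the block belongs to.
    static constexpr int warp_id(int thread_idx)
    {
        return thread_idx / LOGICAL_WARP_THREADS;
    }

    // Position of the thread inside its logical warp.
    static constexpr int lane_id(int thread_idx)
    {
        return thread_idx % LOGICAL_WARP_THREADS;
    }

    // Offset, within the warp's tile, of item i owned by a thread
    // (blocked arrangement: each thread owns consecutive items).
    static constexpr int tile_offset(int thread_idx, int i)
    {
        return lane_id(thread_idx) * ITEMS_PER_THREAD + i;
    }
};

// Configuration from the example: 16-thread logical warps, 4 items each.
using ExampleLayout = LogicalWarpLayout<16, 4>;
```

For instance, thread 37 of the example block falls in logical warp 2 as lane 5, and its last item sits at offset 23 of that warp's 64-key tile.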

Public Functions

WarpMergeSort() = delete
inline WarpMergeSort(typename BlockMergeSortStrategyT::TempStorage &temp_storage)
inline unsigned int get_member_mask() const
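Judging from its name, get_member_mask() reports which hardware lanes belong to the calling thread's logical warp, in the form of a bitmask like the one __syncwarp accepts. The C++ sketch below shows one way such a mask can be computed, assuming 32-lane hardware warps and a power-of-two logical warp size; logical_warp_mask is a hypothetical helper, and CUB's internal computation may differ in detail.

```cpp
#include <cassert>
#include <cstdint>

// Sketch of a lane mask for a logical warp, assuming 32-lane hardware
// warps and a power-of-two LOGICAL_WARP_THREADS <= 32. Bit k is set when
// hardware lane k belongs to the logical warp that contains `lane`.
// Hypothetical helper, not a CUB API.
template <int LOGICAL_WARP_THREADS>
std::uint32_t logical_warp_mask(unsigned lane) // lane in [0, 32)
{
    static_assert(LOGICAL_WARP_THREADS > 0 && LOGICAL_WARP_THREADS <= 32 &&
                  (LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0,
                  "logical warp size must be a power of two in [1, 32]");
    // Run of LOGICAL_WARP_THREADS low bits (all 32 bits for a full warp).
    std::uint32_t mask = 0xFFFFFFFFu >> (32 - LOGICAL_WARP_THREADS);
    // Shift the run to the lanes of this lane's logical warp.
    unsigned warp_id = lane / LOGICAL_WARP_THREADS;
    return mask << (warp_id * LOGICAL_WARP_THREADS);
}
```

With the 16-thread logical warps of the example, lanes 0 to 15 would share the mask 0x0000FFFF and lanes 16 to 31 the mask 0xFFFF0000, so synchronization in one logical warp does not wait on the other.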