// cub/block/block_merge_sort.cuh
// File members: cub/block/block_merge_sort.cuh
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cub/config.cuh>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <cub/thread/thread_sort.cuh>
#include <cub/util_math.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cuda/std/type_traits>
CUB_NAMESPACE_BEGIN
// This implements the DiagonalIntersection algorithm from Merge-Path. Additional details can be found in:
// * S. Odeh, O. Green, Z. Mwassi, O. Shmueli, Y. Birk, "Merge Path - Parallel Merging Made Simple", Multithreaded
// Architectures and Applications (MTAAP) Workshop, IEEE 26th International Parallel & Distributed Processing
// Symposium (IPDPS), 2012
// * S. Odeh, O. Green, Y. Birk, "Merge Path - A Visually Intuitive Approach to Parallel Merging", 2014, URL:
// https://arxiv.org/abs/1406.2628
//! Binary-searches the merge path of two sorted ranges along one cross diagonal.
//!
//! Returns the number of elements of `keys1` that precede the split point on diagonal `diag`; the other
//! `diag - result` elements before the split come from `keys2`. Ties favour `keys1`: an element of `keys2` is
//! ordered before an element of `keys1` only when `binary_pred(key2, key1)` holds, which keeps merging stable.
//!
//! @param keys1        First sorted range
//! @param keys2        Second sorted range
//! @param keys1_count  Number of elements in `keys1`
//! @param keys2_count  Number of elements in `keys2`
//! @param diag         Diagonal index (presumably within [0, keys1_count + keys2_count] -- callers must ensure it)
//! @param binary_pred  "Less than" predicate used to order the keys
template <typename KeyIt1, typename KeyIt2, typename OffsetT, typename BinaryPred>
_CCCL_DEVICE _CCCL_FORCEINLINE OffsetT
MergePath(KeyIt1 keys1, KeyIt2 keys2, OffsetT keys1_count, OffsetT keys2_count, OffsetT diag, BinaryPred binary_pred)
{
  // Lower bound of the search window: keys2 can contribute at most keys2_count elements to the diagonal.
  OffsetT lo = 0;
  if (diag > keys2_count)
  {
    lo = diag - keys2_count;
  }

  // Upper bound of the search window: keys1 can contribute at most keys1_count elements.
  OffsetT hi = (keys1_count < diag) ? keys1_count : diag;

  while (lo < hi)
  {
    const OffsetT probe = cub::MidPoint<OffsetT>(lo, hi);

    // Copy the keys out before calling binary_pred so that proxy references are unwrapped.
    const detail::value_t<KeyIt1> lhs_key = keys1[probe];
    const detail::value_t<KeyIt2> rhs_key = keys2[diag - 1 - probe];

    if (binary_pred(rhs_key, lhs_key))
    {
      hi = probe;
    }
    else
    {
      lo = probe + 1;
    }
  }

  return lo;
}
//! Sequentially merges two sorted runs stored in `keys_shared`, producing exactly ITEMS_PER_THREAD outputs.
//!
//! Run 1 starts at `keys1_beg` (`keys1_count` elements) and run 2 starts at `keys2_beg` (`keys2_count`
//! elements). For every output, `indices` records the position in `keys_shared` the key was taken from. Ties
//! take the run-1 key first (run 2 wins only when `compare_op(key2, key1)` holds), so the merge is stable.
//!
//! Note: the head of each run is loaded eagerly, both initially and right after an element is consumed. The
//! element one past a run's end is therefore read even when it is never output, and `keys_shared` must be
//! readable there. Once run 2 is exhausted, outputs keep coming from the run-1 sequence even past `keys1_end`.
template <typename KeyIt, typename KeyT, typename CompareOp, int ITEMS_PER_THREAD>
_CCCL_DEVICE _CCCL_FORCEINLINE void SerialMerge(
  KeyIt keys_shared,
  int keys1_beg,
  int keys2_beg,
  int keys1_count,
  int keys2_count,
  KeyT (&output)[ITEMS_PER_THREAD],
  int (&indices)[ITEMS_PER_THREAD],
  CompareOp compare_op)
{
  const int run1_end = keys1_beg + keys1_count;
  const int run2_end = keys2_beg + keys2_count;

  int run1_pos = keys1_beg;
  int run2_pos = keys2_beg;

  // Current head of each run
  KeyT run1_head = keys_shared[run1_pos];
  KeyT run2_head = keys_shared[run2_pos];

#pragma unroll
  for (int out = 0; out < ITEMS_PER_THREAD; ++out)
  {
    const bool run2_has_items = run2_pos < run2_end;
    const bool run1_exhausted = run1_pos >= run1_end;

    if (run2_has_items && (run1_exhausted || compare_op(run2_head, run1_head)))
    {
      output[out]  = run2_head;
      indices[out] = run2_pos;
      ++run2_pos;
      run2_head = keys_shared[run2_pos];
    }
    else
    {
      output[out]  = run1_head;
      indices[out] = run1_pos;
      ++run1_pos;
      run1_head = keys_shared[run1_pos];
    }
  }
}
//! Shared implementation of a thread-group-wide stable merge sort.
//!
//! Each thread first sorts its own ITEMS_PER_THREAD keys (and values, if any) with StableOddEvenSort.
//! log2(NUM_THREADS) merge passes then combine pairs of sorted runs held in shared memory. In each pass,
//! MergePath splits the merge work among the threads of a group and SerialMerge produces each thread's outputs.
//!
//! @tparam KeyT                  Key type
//! @tparam ValueT                Value type, or NullType for keys-only sorting
//! @tparam NUM_THREADS           Number of participating threads; must be a power of two
//! @tparam ITEMS_PER_THREAD      Number of items owned by each thread
//! @tparam SynchronizationPolicy CRTP-derived type providing SyncImplementation(), the barrier used between
//!                               shared-memory phases
template <typename KeyT, typename ValueT, int NUM_THREADS, int ITEMS_PER_THREAD, typename SynchronizationPolicy>
class BlockMergeSortStrategy
{
  static_assert(PowerOfTwo<NUM_THREADS>::VALUE, "NUM_THREADS must be a power of two");

private:
  static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * NUM_THREADS;

  // Whether or not there are values to be trucked along with keys
  static constexpr bool KEYS_ONLY = ::cuda::std::is_same<ValueT, NullType>::value;

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
  // Keys and values share one buffer, so every switch between them is separated by Sync().
  // The extra element appears to exist so that SerialMerge's one-past-the-end lookahead read stays in bounds.
  union _TempStorage
  {
    KeyT keys_shared[ITEMS_PER_TILE + 1];
    ValueT items_shared[ITEMS_PER_TILE + 1];
  }; // union TempStorage
#endif // DOXYGEN_SHOULD_SKIP_THIS

  _TempStorage& temp_storage;

  // Fallback storage used when the caller does not supply a TempStorage
  _CCCL_DEVICE _CCCL_FORCEINLINE _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage private_storage;
    return private_storage;
  }

  const unsigned int linear_tid;

public:
  //! Caller-allocatable temporary storage (wraps _TempStorage in Uninitialized)
  struct TempStorage : Uninitialized<_TempStorage>
  {};

  BlockMergeSortStrategy() = delete;

  //! Constructs the sorter on top of a private `__shared__` storage instance
  explicit _CCCL_DEVICE _CCCL_FORCEINLINE BlockMergeSortStrategy(unsigned int linear_tid)
      : temp_storage(PrivateStorage())
      , linear_tid(linear_tid)
  {}

  //! Constructs the sorter on top of caller-provided temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockMergeSortStrategy(TempStorage& temp_storage, unsigned int linear_tid)
      : temp_storage(temp_storage.Alias())
      , linear_tid(linear_tid)
  {}

  //! Linear index of the calling thread within the participating group
  _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int get_linear_tid() const
  {
    return linear_tid;
  }

  //! Sorts a full tile of keys (all ITEMS_PER_TILE keys are treated as valid)
  template <typename CompareOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void Sort(KeyT (&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
  {
    ValueT items[ITEMS_PER_THREAD];
    Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
  }

  //! Sorts the first `valid_items` keys of a partial tile; see the five-argument overload for `oob_default`
  template <typename CompareOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  Sort(KeyT (&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
  {
    ValueT items[ITEMS_PER_THREAD];
    Sort<CompareOp, true>(keys, items, compare_op, valid_items, oob_default);
  }

  //! Sorts a full tile of key/value pairs (all ITEMS_PER_TILE pairs are treated as valid)
  template <typename CompareOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  Sort(KeyT (&keys)[ITEMS_PER_THREAD], ValueT (&items)[ITEMS_PER_THREAD], CompareOp compare_op)
  {
    Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
  }

  //! Sorts key/value pairs across the group.
  //!
  //! @param keys         This thread's keys (blocked arrangement); replaced by the sorted keys
  //! @param items        This thread's values; permuted alongside the keys
  //! @param compare_op   "Less than" comparison functor
  //! @param valid_items  Number of valid items in the tile (only used when IS_LAST_TILE)
  //! @param oob_default  When IS_LAST_TILE, out-of-range key slots are overwritten with the maximum (under
  //!                     compare_op) of oob_default and this thread's valid keys
  template <typename CompareOp, bool IS_LAST_TILE = true>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  Sort(KeyT (&keys)[ITEMS_PER_THREAD],
       ValueT (&items)[ITEMS_PER_THREAD],
       CompareOp compare_op,
       int valid_items,
       KeyT oob_default)
  {
    // Tile index of this thread's first item (signed, to compare cleanly against valid_items)
    const int thread_offset = ITEMS_PER_THREAD * static_cast<int>(linear_tid);

    if (IS_LAST_TILE)
    {
      // Fill every out-of-range key slot with a value that is not less than any of this thread's valid keys.
      // The stable thread-local sort below then keeps all valid keys in the leading (valid) slots.
      // keys[0] must take part in the maximum: if it were skipped and were the largest valid key, the sort
      // would move it into an out-of-range slot and it would be dropped from the result.
      KeyT max_key = oob_default;

#pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        if (thread_offset + item < valid_items)
        {
          max_key = compare_op(max_key, keys[item]) ? keys[item] : max_key;
        }
        else
        {
          keys[item] = max_key;
        }
      }
    }

    // if first element of thread is in input range, stable sort items
    if (!IS_LAST_TILE || thread_offset < valid_items)
    {
      StableOddEvenSort(keys, items, compare_op);
    }

    // Each thread now holds a sorted run; every pass merges pairs of runs spanning
    // `merged_threads_number` threads each, doubling the run length.
    for (int target_merged_threads_number = 2; target_merged_threads_number <= NUM_THREADS;
         target_merged_threads_number *= 2)
    {
      const int merged_threads_number = target_merged_threads_number / 2;
      const int mask                  = target_merged_threads_number - 1;

      // Wait until every thread has finished reading shared memory from the previous pass
      Sync();

      // store keys in shmem
#pragma unroll
      for (int item = 0; item < ITEMS_PER_THREAD; ++item)
      {
        temp_storage.keys_shared[thread_offset + item] = keys[item];
      }

      Sync();

      // Source position (in shared memory) of each merged output, used to permute the values
      int indices[ITEMS_PER_THREAD];

      const int first_thread_idx_in_thread_group_being_merged = ~mask & static_cast<int>(linear_tid);
      const int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged;
      const int size  = ITEMS_PER_THREAD * merged_threads_number;

      const int thread_idx_in_thread_group_being_merged = mask & static_cast<int>(linear_tid);

      // Output rank of this thread's first merged item within its group (clamped to the valid range)
      const int diag = (cub::min)(valid_items, ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged);

      // The two adjacent runs merged by this group, clipped to the valid items
      const int keys1_beg = (cub::min)(valid_items, start);
      const int keys1_end = (cub::min)(valid_items, keys1_beg + size);
      const int keys2_beg = keys1_end;
      const int keys2_end = (cub::min)(valid_items, keys2_beg + size);

      const int keys1_count = keys1_end - keys1_beg;
      const int keys2_count = keys2_end - keys2_beg;

      const int partition_diag = MergePath(
        &temp_storage.keys_shared[keys1_beg],
        &temp_storage.keys_shared[keys2_beg],
        keys1_count,
        keys2_count,
        diag,
        compare_op);

      const int keys1_beg_loc   = keys1_beg + partition_diag;
      const int keys1_end_loc   = keys1_end;
      const int keys2_beg_loc   = keys2_beg + diag - partition_diag;
      const int keys2_end_loc   = keys2_end;
      const int keys1_count_loc = keys1_end_loc - keys1_beg_loc;
      const int keys2_count_loc = keys2_end_loc - keys2_beg_loc;

      SerialMerge(
        &temp_storage.keys_shared[0],
        keys1_beg_loc,
        keys2_beg_loc,
        keys1_count_loc,
        keys2_count_loc,
        keys,
        indices,
        compare_op);

      if (!KEYS_ONLY)
      {
        // Values reuse the keys' shared memory (union), so wait until all key reads are done
        Sync();

        // store items in shmem
#pragma unroll
        for (int item = 0; item < ITEMS_PER_THREAD; ++item)
        {
          temp_storage.items_shared[thread_offset + item] = items[item];
        }

        Sync();

        // gather items from shmem following the permutation produced by the key merge
#pragma unroll
        for (int item = 0; item < ITEMS_PER_THREAD; ++item)
        {
          items[item] = temp_storage.items_shared[indices[item]];
        }
      }
    }
  } // Sort

  //! Same as Sort (which is already stable): full tile of keys
  template <typename CompareOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void StableSort(KeyT (&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
  {
    Sort(keys, compare_op);
  }

  //! Same as Sort (which is already stable): full tile of key/value pairs
  template <typename CompareOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  StableSort(KeyT (&keys)[ITEMS_PER_THREAD], ValueT (&items)[ITEMS_PER_THREAD], CompareOp compare_op)
  {
    Sort(keys, items, compare_op);
  }

  //! Same as Sort (which is already stable): partial tile of keys
  template <typename CompareOp>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  StableSort(KeyT (&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
  {
    Sort(keys, compare_op, valid_items, oob_default);
  }

  //! Same as Sort (which is already stable): key/value pairs, optionally a partial tile
  template <typename CompareOp, bool IS_LAST_TILE = true>
  _CCCL_DEVICE _CCCL_FORCEINLINE void StableSort(
    KeyT (&keys)[ITEMS_PER_THREAD],
    ValueT (&items)[ITEMS_PER_THREAD],
    CompareOp compare_op,
    int valid_items,
    KeyT oob_default)
  {
    Sort<CompareOp, IS_LAST_TILE>(keys, items, compare_op, valid_items, oob_default);
  }

private:
  // Barrier supplied by the CRTP-derived SynchronizationPolicy
  _CCCL_DEVICE _CCCL_FORCEINLINE void Sync() const
  {
    static_cast<const SynchronizationPolicy*>(this)->SyncImplementation();
  }
};
//! Block-wide stable merge sort for a (possibly multi-dimensional) thread block.
//!
//! Thin CRTP front end over BlockMergeSortStrategy. It derives the row-major linear thread index from the
//! block dimensions and supplies CTA_SYNC() as the barrier between shared-memory phases.
template <typename KeyT,
          int BLOCK_DIM_X,
          int ITEMS_PER_THREAD,
          typename ValueT = NullType,
          int BLOCK_DIM_Y = 1,
          int BLOCK_DIM_Z = 1>
class BlockMergeSort
    : public BlockMergeSortStrategy<KeyT,
                                    ValueT,
                                    BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
                                    ITEMS_PER_THREAD,
                                    BlockMergeSort<KeyT, BLOCK_DIM_X, ITEMS_PER_THREAD, ValueT, BLOCK_DIM_Y, BLOCK_DIM_Z>>
{
private:
  // Total number of threads in the block and number of items in one tile
  static constexpr int BLOCK_THREADS  = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;
  static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS;

  using BlockMergeSortStrategyT = BlockMergeSortStrategy<KeyT, ValueT, BLOCK_THREADS, ITEMS_PER_THREAD, BlockMergeSort>;

  // The strategy reaches SyncImplementation() through a static_cast to this type
  friend BlockMergeSortStrategyT;

  // Row-major linear index of the calling thread within the block
  static _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LinearTid()
  {
    return RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z);
  }

  // Block-wide barrier used between shared-memory phases
  _CCCL_DEVICE _CCCL_FORCEINLINE void SyncImplementation() const
  {
    CTA_SYNC();
  }

public:
  //! Constructs the sorter on top of a private shared-memory allocation
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockMergeSort()
      : BlockMergeSortStrategyT(LinearTid())
  {}

  //! Constructs the sorter on top of caller-provided temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE explicit BlockMergeSort(typename BlockMergeSortStrategyT::TempStorage& temp_storage)
      : BlockMergeSortStrategyT(temp_storage, LinearTid())
  {}
};
CUB_NAMESPACE_END