cub::DevicePartition
Defined in cub/device/device_partition.cuh
-
struct DevicePartition
DevicePartition provides device-wide, parallel operations for partitioning sequences of data items residing within device-accessible memory.
Overview
These operations apply a selection criterion to construct a partitioned output sequence from the selected and unselected items of a specified input sequence.
Usage Considerations
Dynamic parallelism. DevicePartition methods can be called within kernel code on devices in which CUDA dynamic parallelism is supported.
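As an illustration of this usage consideration, the following is a minimal, hedged sketch of a device-side invocation. It assumes compilation with relocatable device code (-rdc=true) and a device that supports dynamic parallelism; the kernel name and the pre-sized temporary storage passed in from the host are illustrative, not part of the CUB API.

#include <cstddef>
#include <cub/cub.cuh>

// Hedged sketch: calling DevicePartition::Flagged from kernel code under
// CUDA dynamic parallelism. d_temp_storage / temp_storage_bytes are assumed
// to have been sized (via the usual nullptr query) and allocated on the host.
__global__ void PartitionFromDevice(void *d_temp_storage,
                                    std::size_t temp_storage_bytes,
                                    const int *d_in,
                                    const char *d_flags,
                                    int *d_out,
                                    int *d_num_selected_out,
                                    int num_items)
{
  // A single thread issues the device-wide call
  if (blockIdx.x == 0 && threadIdx.x == 0)
  {
    cub::DevicePartition::Flagged(d_temp_storage, temp_storage_bytes,
                                  d_in, d_flags, d_out, d_num_selected_out,
                                  num_items);
  }
}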
Performance
The work-complexity of partition as a function of input size is linear, resulting in performance throughput that plateaus with problem sizes large enough to saturate the GPU.
Public Static Functions
-
template<typename InputIteratorT, typename FlagIterator, typename OutputIteratorT, typename NumSelectedIteratorT, typename NumItemsT>
static inline cudaError_t Flagged(void *d_temp_storage, size_t &temp_storage_bytes, InputIteratorT d_in, FlagIterator d_flags, OutputIteratorT d_out, NumSelectedIteratorT d_num_selected_out, NumItemsT num_items, cudaStream_t stream = 0)
Uses the d_flags sequence to split the corresponding items from d_in into a partitioned sequence d_out. The total number of items copied into the first partition is written to d_num_selected_out.
The value type of d_flags must be castable to bool (e.g., bool, char, int, etc.).
Copies of the selected items are compacted into d_out and maintain their original relative ordering; however, copies of the unselected items are compacted into the rear of d_out in reverse order.
The range [d_out, d_out + num_items) shall not overlap [d_in, d_in + num_items) nor [d_flags, d_flags + num_items) in any way. The range [d_in, d_in + num_items) may overlap [d_flags, d_flags + num_items).
When d_temp_storage is nullptr, no work is done and the required allocation size is returned in temp_storage_bytes.
Snippet
The code snippet below illustrates the compaction of items selected from an int device vector.

#include <cub/cub.cuh>   // or equivalently <cub/device/device_partition.cuh>

// Declare, allocate, and initialize device-accessible pointers for
// input, flags, and output
int  num_items;              // e.g., 8
int  *d_in;                  // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
char *d_flags;               // e.g., [1, 0, 0, 1, 0, 1, 1, 0]
int  *d_out;                 // e.g., [ ,  ,  ,  ,  ,  ,  ,  ]
int  *d_num_selected_out;    // e.g., [ ]
...

// Determine temporary device storage requirements
void *d_temp_storage = nullptr;
std::size_t temp_storage_bytes = 0;
cub::DevicePartition::Flagged(
  d_temp_storage, temp_storage_bytes,
  d_in, d_flags, d_out, d_num_selected_out, num_items);

// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);

// Run selection
cub::DevicePartition::Flagged(
  d_temp_storage, temp_storage_bytes,
  d_in, d_flags, d_out, d_num_selected_out, num_items);

// d_out                 <-- [1, 4, 6, 7, 8, 5, 3, 2]
// d_num_selected_out    <-- [4]
- Template Parameters
InputIteratorT – [inferred] Random-access input iterator type for reading input items (may be a simple pointer type)
FlagIterator – [inferred] Random-access input iterator type for reading selection flags (may be a simple pointer type)
OutputIteratorT – [inferred] Random-access output iterator type for writing output items (may be a simple pointer type)
NumSelectedIteratorT – [inferred] Output iterator type for recording the number of items selected (may be a simple pointer type)
NumItemsT – [inferred] Type of num_items
- Parameters
d_temp_storage – [in] Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to temp_storage_bytes and no work is done.
temp_storage_bytes – [inout] Reference to size in bytes of d_temp_storage allocation
d_in – [in] Pointer to the input sequence of data items
d_flags – [in] Pointer to the input sequence of selection flags
d_out – [out] Pointer to the output sequence of partitioned data items
d_num_selected_out – [out] Pointer to the output total number of items selected (i.e., the offset of the unselected partition); a usage sketch follows this parameter list
num_items – [in] Total number of items to select from
stream – [in] [optional] CUDA stream to launch kernels within. Default is stream 0.
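As a follow-up to the snippet and parameters above, the sketch below shows one way a caller might read the selection count back to the host and locate the two partitions of d_out. It reuses the names from the snippet; the host-side variables are illustrative only.

// Hedged sketch: retrieve the number of selected items and address both
// partitions of d_out. Assumes the Flagged snippet above has already run.
int num_selected = 0;
cudaMemcpy(&num_selected, d_num_selected_out, sizeof(int),
           cudaMemcpyDeviceToHost);

// Selected items occupy [d_out, d_out + num_selected) in their original order.
int *d_selected_begin   = d_out;

// Unselected items occupy [d_out + num_selected, d_out + num_items),
// stored in reverse of their original input order.
int *d_unselected_begin = d_out + num_selected;
int  num_unselected     = num_items - num_selected;

// Release the temporary storage allocated for the snippet.
cudaFree(d_temp_storage);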
-
template<typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT, typename SelectOp, typename NumItemsT>
static inline cudaError_t If(void *d_temp_storage, size_t &temp_storage_bytes, InputIteratorT d_in, OutputIteratorT d_out, NumSelectedIteratorT d_num_selected_out, NumItemsT num_items, SelectOp select_op, cudaStream_t stream = 0)
Uses the select_op functor to split the corresponding items from d_in into a partitioned sequence d_out. The total number of items copied into the first partition is written to d_num_selected_out.
Copies of the selected items are compacted into d_out and maintain their original relative ordering; however, copies of the unselected items are compacted into the rear of d_out in reverse order.
The range [d_out, d_out + num_items) shall not overlap [d_in, d_in + num_items) in any way.
When d_temp_storage is nullptr, no work is done and the required allocation size is returned in temp_storage_bytes.
Snippet
The code snippet below illustrates the compaction of items selected from an int device vector.

#include <cub/cub.cuh>   // or equivalently <cub/device/device_partition.cuh>

// Functor type for selecting values less than some criteria
struct LessThan
{
  int compare;

  CUB_RUNTIME_FUNCTION __forceinline__
  explicit LessThan(int compare) : compare(compare) {}

  CUB_RUNTIME_FUNCTION __forceinline__
  bool operator()(const int &a) const
  {
    return (a < compare);
  }
};

// Declare, allocate, and initialize device-accessible pointers for
// input and output
int  num_items;              // e.g., 8
int  *d_in;                  // e.g., [0, 2, 3, 9, 5, 2, 81, 8]
int  *d_out;                 // e.g., [ ,  ,  ,  ,  ,  ,  ,  ]
int  *d_num_selected_out;    // e.g., [ ]
LessThan select_op(7);
...

// Determine temporary device storage requirements
void *d_temp_storage = nullptr;
std::size_t temp_storage_bytes = 0;
cub::DevicePartition::If(
  d_temp_storage, temp_storage_bytes,
  d_in, d_out, d_num_selected_out, num_items, select_op);

// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);

// Run selection
cub::DevicePartition::If(
  d_temp_storage, temp_storage_bytes,
  d_in, d_out, d_num_selected_out, num_items, select_op);

// d_out                 <-- [0, 2, 3, 5, 2, 8, 81, 9]
// d_num_selected_out    <-- [5]
- Template Parameters
InputIteratorT – [inferred] Random-access input iterator type for reading input items (may be a simple pointer type)
OutputIteratorT – [inferred] Random-access output iterator type for writing output items (may be a simple pointer type)
NumSelectedIteratorT – [inferred] Output iterator type for recording the number of items selected (may be a simple pointer type)
SelectOp – [inferred] Selection functor type having member bool operator()(const T &a)
NumItemsT – [inferred] Type of num_items
- Parameters
d_temp_storage – [in] Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to temp_storage_bytes and no work is done.
temp_storage_bytes – [inout] Reference to size in bytes of d_temp_storage allocation
d_in – [in] Pointer to the input sequence of data items
d_out – [out] Pointer to the output sequence of partitioned data items
d_num_selected_out – [out] Pointer to the output total number of items selected (i.e., the offset of the unselected partition)
num_items – [in] Total number of items to select from
select_op – [in] Unary selection operator (a device-lambda sketch follows this parameter list)
stream – [in] [optional] CUDA stream to launch kernels within. Default is stream 0.
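The selection operator does not have to be a hand-written functor. As a hedged sketch (assuming the translation unit is compiled with nvcc's --extended-lambda flag, and reusing d_in, d_out, d_num_selected_out, and num_items from the snippet above), a device lambda can serve as select_op:

// Hedged sketch: an extended __device__ lambda as the selection operator.
auto is_even = [] __device__ (const int &a) { return (a % 2) == 0; };

void *d_temp_storage = nullptr;
std::size_t temp_storage_bytes = 0;

// Size query, then allocation, then the actual partition
cub::DevicePartition::If(
  d_temp_storage, temp_storage_bytes,
  d_in, d_out, d_num_selected_out, num_items, is_even);

cudaMalloc(&d_temp_storage, temp_storage_bytes);

cub::DevicePartition::If(
  d_temp_storage, temp_storage_bytes,
  d_in, d_out, d_num_selected_out, num_items, is_even);

// Even values are written to the front of d_out in their original order;
// odd values are written to the rear of d_out in reverse order.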
-
template<typename InputIteratorT, typename FirstOutputIteratorT, typename SecondOutputIteratorT, typename UnselectedOutputIteratorT, typename NumSelectedIteratorT, typename SelectFirstPartOp, typename SelectSecondPartOp>
static inline cudaError_t If(void *d_temp_storage, std::size_t &temp_storage_bytes, InputIteratorT d_in, FirstOutputIteratorT d_first_part_out, SecondOutputIteratorT d_second_part_out, UnselectedOutputIteratorT d_unselected_out, NumSelectedIteratorT d_num_selected_out, int num_items, SelectFirstPartOp select_first_part_op, SelectSecondPartOp select_second_part_op, cudaStream_t stream = 0)
Uses two functors to split the corresponding items from d_in into three partitioned sequences: d_first_part_out, d_second_part_out, and d_unselected_out. The total number of items copied into the first partition is written to d_num_selected_out[0], while the total number of items copied into the second partition is written to d_num_selected_out[1].
Copies of the items selected by select_first_part_op are compacted into d_first_part_out and maintain their original relative ordering.
Copies of the items selected by select_second_part_op are compacted into d_second_part_out and maintain their original relative ordering.
Copies of the unselected items are compacted into d_unselected_out in reverse order.
The ranges [d_in, d_in + num_items), [d_first_part_out, d_first_part_out + d_num_selected_out[0]), [d_second_part_out, d_second_part_out + d_num_selected_out[1]), and [d_unselected_out, d_unselected_out + num_items - d_num_selected_out[0] - d_num_selected_out[1]) shall not overlap in any way.
Snippet
The code snippet below illustrates how this algorithm can partition an input vector into small, medium, and large items so that the relative order of items remains deterministic.
Let's consider any value that doesn't exceed six a small one. On the other hand, any value that exceeds 50 will be considered a large one. Since the value that defines the small part doesn't match the one that defines the large part, the intermediate segment is implied.
These definitions partition the value space into three categories. We want to preserve the order in which items appear in the input vector. Since the algorithm provides stable partitioning, this is possible.
Since the number of items in each category is unknown beforehand, we need three output arrays of num_items elements each. To reduce the memory requirements, we can combine the output storage for two categories.
Since each value falls into precisely one category, it's safe to add "large" values into the head of the shared output vector and the "middle" values into its tail. To add items into the tail of the output array, we can use thrust::reverse_iterator.

#include <cub/cub.cuh>   // or equivalently <cub/device/device_partition.cuh>
#include <thrust/iterator/reverse_iterator.h>

// Functor type for selecting values less than some criteria
struct LessThan
{
  int compare;

  __host__ __device__ __forceinline__
  explicit LessThan(int compare) : compare(compare) {}

  __host__ __device__ __forceinline__
  bool operator()(const int &a) const
  {
    return a < compare;
  }
};

// Functor type for selecting values greater than some criteria
struct GreaterThan
{
  int compare;

  __host__ __device__ __forceinline__
  explicit GreaterThan(int compare) : compare(compare) {}

  __host__ __device__ __forceinline__
  bool operator()(const int &a) const
  {
    return a > compare;
  }
};

// Declare, allocate, and initialize device-accessible pointers for
// input and output
int  num_items;                     // e.g., 8
int  *d_in;                         // e.g., [0, 2, 3, 9, 5, 2, 81, 8]
int  *d_large_and_unselected_out;   // e.g., [ ,  ,  ,  ,  ,  ,  ,  ]
int  *d_small_out;                  // e.g., [ ,  ,  ,  ,  ,  ,  ,  ]
int  *d_num_selected_out;           // e.g., [ ,  ]
thrust::reverse_iterator<int*> unselected_out(d_large_and_unselected_out + num_items);
LessThan small_items_selector(7);
GreaterThan large_items_selector(50);
...

// Determine temporary device storage requirements
void *d_temp_storage = nullptr;
std::size_t temp_storage_bytes = 0;
cub::DevicePartition::If(
  d_temp_storage, temp_storage_bytes,
  d_in, d_large_and_unselected_out, d_small_out, unselected_out,
  d_num_selected_out, num_items,
  large_items_selector, small_items_selector);

// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);

// Run selection
cub::DevicePartition::If(
  d_temp_storage, temp_storage_bytes,
  d_in, d_large_and_unselected_out, d_small_out, unselected_out,
  d_num_selected_out, num_items,
  large_items_selector, small_items_selector);

// d_large_and_unselected_out  <-- [ 81,  ,  ,  ,  ,  , 8, 9 ]
// d_small_out                 <-- [  0, 2, 3, 5, 2,  ,  ,   ]
// d_num_selected_out          <-- [  1, 5 ]
- Template Parameters
InputIteratorT – [inferred] Random-access input iterator type for reading input items (may be a simple pointer type)
FirstOutputIteratorT – [inferred] Random-access output iterator type for writing output items selected by first operator (may be a simple pointer type)
SecondOutputIteratorT – [inferred] Random-access output iterator type for writing output items selected by second operator (may be a simple pointer type)
UnselectedOutputIteratorT – [inferred] Random-access output iterator type for writing unselected items (may be a simple pointer type)
NumSelectedIteratorT – [inferred] Output iterator type for recording the number of items selected (may be a simple pointer type)
SelectFirstPartOp – [inferred] Selection functor type having member bool operator()(const T &a)
SelectSecondPartOp – [inferred] Selection functor type having member bool operator()(const T &a)
- Parameters
d_temp_storage – [in] Device-accessible allocation of temporary storage. When nullptr, the required allocation size is written to temp_storage_bytes and no work is done.
temp_storage_bytes – [inout] Reference to size in bytes of d_temp_storage allocation
d_in – [in] Pointer to the input sequence of data items
d_first_part_out – [out] Pointer to the output sequence of data items selected by select_first_part_op
d_second_part_out – [out] Pointer to the output sequence of data items selected by select_second_part_op
d_unselected_out – [out] Pointer to the output sequence of unselected data items
d_num_selected_out – [out] Pointer to the output array with two elements, where the total number of items selected by select_first_part_op is stored as d_num_selected_out[0] and the total number of items selected by select_second_part_op is stored as d_num_selected_out[1]; a usage sketch follows this parameter list
num_items – [in] Total number of items to select from
select_first_part_op – [in] Unary selection operator to select d_first_part_out
select_second_part_op – [in] Unary selection operator to select d_second_part_out
stream – [in] [optional] CUDA stream to launch kernels within. Default is stream 0.
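As a follow-up to the three-way snippet above, the sketch below shows how the two partition sizes might be read back on the host and used to derive the size of the unselected (medium) segment. It reuses the snippet's names; the host-side variables are illustrative only.

// Hedged sketch: retrieve both partition sizes and derive the size of the
// unselected (medium) segment. Assumes the snippet above has already run.
int h_num_selected[2] = {0, 0};
cudaMemcpy(h_num_selected, d_num_selected_out, 2 * sizeof(int),
           cudaMemcpyDeviceToHost);

int num_large  = h_num_selected[0];  // selected by large_items_selector
int num_small  = h_num_selected[1];  // selected by small_items_selector
int num_medium = num_items - num_large - num_small;

// With the shared buffer from the snippet, the large items occupy the first
// num_large slots of d_large_and_unselected_out and the medium items occupy
// its last num_medium slots; the small items live in d_small_out.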