cub::WarpReduce
Defined in cub/warp/warp_reduce.cuh
-
template<typename T, int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS, int LEGACY_PTX_ARCH = 0>
class WarpReduce

The WarpReduce class provides collective methods for computing a parallel reduction of items partitioned across a CUDA thread warp.

Overview
A reduction (or fold) uses a binary combining operator to compute a single aggregate from a list of input elements.
Supports “logical” warps smaller than the physical warp size (e.g., logical warps of 8 threads); see the sketch after this list
The number of entrant threads must be a multiple of LOGICAL_WARP_THREADS
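To illustrate the logical-warp feature, the following is a minimal sketch (hypothetical kernel and buffer names; a 64-thread block forms four logical warps of 16 threads):

#include <cub/cub.cuh>

// Hypothetical kernel launched with 64 threads per block
__global__ void LogicalWarpSumKernel(int *d_in, int *d_out)
{
    // Specialize WarpReduce for logical warps of 16 threads
    using WarpReduce = cub::WarpReduce<int, 16>;

    // One TempStorage per logical warp (64 / 16 = 4)
    __shared__ typename WarpReduce::TempStorage temp_storage[4];

    int logical_warp_id = threadIdx.x / 16;
    int thread_data = d_in[blockIdx.x * blockDim.x + threadIdx.x];

    // Each logical warp's sum is returned to its lane0
    int aggregate = WarpReduce(temp_storage[logical_warp_id]).Sum(thread_data);

    if (threadIdx.x % 16 == 0)
        d_out[blockIdx.x * 4 + logical_warp_id] = aggregate;
}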
Performance Considerations
Uses special instructions when applicable (e.g., warp SHFL instructions)
Uses synchronization-free communication between warp lanes when applicable
Incurs zero bank conflicts for most types
Computation is slightly more efficient (i.e., having lower instruction overhead) for:
Summation (vs. generic reduction)
The architecture’s warp size is a whole multiple of LOGICAL_WARP_THREADS
Simple Examples
Every thread in the warp uses the WarpReduce class by first specializing the WarpReduce type, then instantiating an instance with parameters for communication, and finally invoking one or more collective member functions.
The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one for each of the four 32-thread warps).
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for 4 warps
    __shared__ typename WarpReduce::TempStorage temp_storage[4];

    // Obtain one input item per thread
    int thread_data = ...

    // Return the warp-wide sums to each lane0 (threads 0, 32, 64, and 96)
    int warp_id = threadIdx.x / 32;
    int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data);
}
Suppose the set of input thread_data across the block of threads is {0, 1, 2, 3, ..., 127}. The corresponding output aggregate in threads 0, 32, 64, and 96 will be 496, 1520, 2544, and 3568, respectively (and is undefined in other threads).

The code snippet below illustrates a single warp sum reduction within a block of 128 threads.
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for one warp
    __shared__ typename WarpReduce::TempStorage temp_storage;
    ...

    // Only the first warp performs a reduction
    if (threadIdx.x < 32)
    {
        // Obtain one input item per thread
        int thread_data = ...

        // Return the warp-wide sum to lane0
        int aggregate = WarpReduce(temp_storage).Sum(thread_data);
    }
}
Suppose the set of input thread_data across the warp of threads is {0, 1, 2, 3, ..., 31}. The corresponding output aggregate in thread0 will be 496 (and is undefined in other threads).
- Template Parameters
T – The reduction input/output element type
LOGICAL_WARP_THREADS – [optional] The number of threads per “logical” warp (may be less than the number of hardware warp threads). Default is the warp size of the targeted CUDA compute-capability (e.g., 32 threads for SM20).
LEGACY_PTX_ARCH – [optional] Unused.
Collective constructors
-
inline WarpReduce(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage. Logical warp and lane identifiers are constructed from threadIdx.x.
- Parameters
temp_storage – [in] Reference to memory allocation having layout type TempStorage
Summation reductions
-
inline T Sum(T input)
Computes a warp-wide sum in the calling warp. The output is valid in warp lane0.
A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates four concurrent warp sum reductions within a block of 128 threads (one for each of the four 32-thread warps).
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for 4 warps
    __shared__ typename WarpReduce::TempStorage temp_storage[4];

    // Obtain one input item per thread
    int thread_data = ...

    // Return the warp-wide sums to each lane0
    int warp_id = threadIdx.x / 32;
    int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data);
}
Suppose the set of input thread_data across the block of threads is {0, 1, 2, 3, ..., 127}. The corresponding output aggregate in threads 0, 32, 64, and 96 will be 496, 1520, 2544, and 3568, respectively (and is undefined in other threads).
- Parameters
input – [in] Calling thread’s input
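The __syncwarp() caveat above can be illustrated with a minimal sketch (hypothetical single-warp kernel and buffer names) that reuses one TempStorage for two back-to-back sums:

#include <cub/cub.cuh>

// Hypothetical kernel launched with one 32-thread warp
__global__ void ReuseKernel(int *d_a, int *d_b, int *d_out)
{
    using WarpReduce = cub::WarpReduce<int>;
    __shared__ typename WarpReduce::TempStorage temp_storage;

    // First warp-wide sum (valid in lane0)
    int sum_a = WarpReduce(temp_storage).Sum(d_a[threadIdx.x]);

    // Barrier before the same temporary storage is reused
    __syncwarp();

    // Second warp-wide sum over the same storage (valid in lane0)
    int sum_b = WarpReduce(temp_storage).Sum(d_b[threadIdx.x]);

    if (threadIdx.x == 0)
        d_out[0] = sum_a + sum_b;
}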
-
inline T Sum(T input, int valid_items)
Computes a partially-full warp-wide sum in the calling warp. The output is valid in warp lane0.
All threads across the calling warp must agree on the same value for valid_items. Otherwise the result is undefined.

A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates a sum reduction within a single, partially-full block of 32 threads (one warp).
#include <cub/cub.cuh>

__global__ void ExampleKernel(int *d_data, int valid_items)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for one warp
    __shared__ typename WarpReduce::TempStorage temp_storage;

    // Obtain one input item per thread if in range
    int thread_data;
    if (threadIdx.x < valid_items)
        thread_data = d_data[threadIdx.x];

    // Return the warp-wide sums to each lane0
    int aggregate = WarpReduce(temp_storage).Sum(thread_data, valid_items);
}
Suppose the input d_data is {0, 1, 2, 3, 4, ...} and valid_items is 4. The corresponding output aggregate in lane0 is 6 (and is undefined in other threads).
- Parameters
input – [in] Calling thread’s input
valid_items – [in] Total number of valid items in the calling thread’s logical warp (may be less than LOGICAL_WARP_THREADS)
-
template<typename FlagT>
inline T HeadSegmentedSum(T input, FlagT head_flag)
Computes a segmented sum in the calling warp where segments are defined by head-flags. The sum of each segment is returned to the first lane in that segment (which always includes lane0).
A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates a head-segmented warp sum reduction within a block of 32 threads (one warp).
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for one warp
    __shared__ typename WarpReduce::TempStorage temp_storage;

    // Obtain one input item and flag per thread
    int thread_data = ...
    int head_flag = ...

    // Return the warp-wide sums to each lane0
    int aggregate = WarpReduce(temp_storage).HeadSegmentedSum(thread_data, head_flag);
}
Suppose the set of input thread_data and head_flag across the block of threads is {0, 1, 2, 3, ..., 31} and is {1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0}, respectively. The corresponding output aggregate in threads 0, 4, 8, etc. will be 6, 22, 38, etc. (and is undefined in other threads).
- Template Parameters
FlagT – [inferred] The flag type (must be an integer type)
- Parameters
input – [in] Calling thread’s input
head_flag – [in] Head flag denoting whether or not input is the start of a new segment
-
template<typename FlagT>
inline T TailSegmentedSum(T input, FlagT tail_flag)
Computes a segmented sum in the calling warp where segments are defined by tail-flags. The sum of each segment is returned to the first lane in that segment (which always includes lane0).
A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates a tail-segmented warp sum reduction within a block of 32 threads (one warp).
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for one warp
    __shared__ typename WarpReduce::TempStorage temp_storage;

    // Obtain one input item and flag per thread
    int thread_data = ...
    int tail_flag = ...

    // Return the warp-wide sums to each lane0
    int aggregate = WarpReduce(temp_storage).TailSegmentedSum(thread_data, tail_flag);
}
Suppose the set of input thread_data and tail_flag across the block of threads is {0, 1, 2, 3, ..., 31} and is {0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1}, respectively. The corresponding output aggregate in threads 0, 4, 8, etc. will be 6, 22, 38, etc. (and is undefined in other threads).
- Template Parameters
FlagT – [inferred] The flag type (must be an integer type)
- Parameters
input – [in] Calling thread’s input
tail_flag – [in] Tail flag denoting whether or not input is the end of the current segment
Generic reductions
-
template<typename ReductionOp>
inline T Reduce(T input, ReductionOp reduction_op)
Computes a warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp lane0.
Supports non-commutative reduction operators
A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates four concurrent warp max reductions within a block of 128 threads (one for each of the four 32-thread warps).
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for 4 warps
    __shared__ typename WarpReduce::TempStorage temp_storage[4];

    // Obtain one input item per thread
    int thread_data = ...

    // Return the warp-wide reductions to each lane0
    int warp_id = threadIdx.x / 32;
    int aggregate = WarpReduce(temp_storage[warp_id]).Reduce(thread_data, cuda::maximum<>{});
}
Suppose the set of input thread_data across the block of threads is {0, 1, 2, 3, ..., 127}. The corresponding output aggregate in threads 0, 32, 64, and 96 will be 31, 63, 95, and 127, respectively (and is undefined in other threads).
- Template Parameters
ReductionOp – [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b)
- Parameters
input – [in] Calling thread’s input
reduction_op – [in] Binary reduction operator
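Beyond the cuda::maximum<>{} functor shown above, any type with a matching operator() can be supplied. A minimal sketch follows (the CustomMin functor and kernel names are hypothetical):

#include <cub/cub.cuh>

// Hypothetical user-defined binary reduction functor
struct CustomMin
{
    __device__ int operator()(const int &a, const int &b) const
    {
        return (b < a) ? b : a;
    }
};

// Hypothetical kernel launched with one 32-thread warp
__global__ void MinKernel(int *d_in, int *d_out)
{
    using WarpReduce = cub::WarpReduce<int>;
    __shared__ typename WarpReduce::TempStorage temp_storage;

    int thread_data = d_in[threadIdx.x];

    // Warp-wide minimum, valid in lane0
    int warp_min = WarpReduce(temp_storage).Reduce(thread_data, CustomMin{});

    if (threadIdx.x == 0)
        d_out[0] = warp_min;
}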
-
template<typename ReductionOp>
inline T Reduce(T input, ReductionOp reduction_op, int valid_items)
Computes a partially-full warp-wide reduction in the calling warp using the specified binary reduction functor. The output is valid in warp lane0.
All threads across the calling warp must agree on the same value for valid_items. Otherwise the result is undefined.

Supports non-commutative reduction operators

A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates a max reduction within a single, partially-full block of 32 threads (one warp).
#include <cub/cub.cuh>

__global__ void ExampleKernel(int *d_data, int valid_items)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for one warp
    __shared__ typename WarpReduce::TempStorage temp_storage;

    // Obtain one input item per thread if in range
    int thread_data;
    if (threadIdx.x < valid_items)
        thread_data = d_data[threadIdx.x];

    // Return the warp-wide reductions to each lane0
    int aggregate = WarpReduce(temp_storage).Reduce(thread_data, cuda::maximum<>{}, valid_items);
}
Suppose the input d_data is {0, 1, 2, 3, 4, ...} and valid_items is 4. The corresponding output aggregate in thread0 is 3 (and is undefined in other threads).
- Template Parameters
ReductionOp – [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b)
- Parameters
input – [in] Calling thread’s input
reduction_op – [in] Binary reduction operator
valid_items – [in] Total number of valid items in the calling thread’s logical warp (may be less than LOGICAL_WARP_THREADS)
-
template<typename ReductionOp, typename FlagT>
inline T HeadSegmentedReduce(T input, FlagT head_flag, ReductionOp reduction_op)
Computes a segmented reduction in the calling warp where segments are defined by head-flags. The reduction of each segment is returned to the first lane in that segment (which always includes lane0).
Supports non-commutative reduction operators
A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates a head-segmented warp max reduction within a block of 32 threads (one warp).
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for one warp
    __shared__ typename WarpReduce::TempStorage temp_storage;

    // Obtain one input item and flag per thread
    int thread_data = ...
    int head_flag = ...

    // Return the warp-wide reductions to each lane0
    int aggregate = WarpReduce(temp_storage).HeadSegmentedReduce(thread_data, head_flag, cuda::maximum<>{});
}
Suppose the set of input thread_data and head_flag across the block of threads is {0, 1, 2, 3, ..., 31} and is {1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0}, respectively. The corresponding output aggregate in threads 0, 4, 8, etc. will be 3, 7, 11, etc. (and is undefined in other threads).
- Template Parameters
ReductionOp – [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b)
- Parameters
input – [in] Calling thread’s input
head_flag – [in] Head flag denoting whether or not input is the start of a new segment
reduction_op – [in] Reduction operator
-
template<typename ReductionOp, typename FlagT>
inline T TailSegmentedReduce(T input, FlagT tail_flag, ReductionOp reduction_op)
Computes a segmented reduction in the calling warp where segments are defined by tail-flags. The reduction of each segment is returned to the first lane in that segment (which always includes lane0).
Supports non-commutative reduction operators
A subsequent __syncwarp() warp-wide barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet
The code snippet below illustrates a tail-segmented warp max reduction within a block of 32 threads (one warp).
#include <cub/cub.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize WarpReduce for type int
    using WarpReduce = cub::WarpReduce<int>;

    // Allocate WarpReduce shared memory for one warp
    __shared__ typename WarpReduce::TempStorage temp_storage;

    // Obtain one input item and flag per thread
    int thread_data = ...
    int tail_flag = ...

    // Return the warp-wide reductions to each lane0
    int aggregate = WarpReduce(temp_storage).TailSegmentedReduce(thread_data, tail_flag, cuda::maximum<>{});
}
Suppose the set of input thread_data and tail_flag across the block of threads is {0, 1, 2, 3, ..., 31} and is {0, 0, 0, 1, 0, 0, 0, 1, ..., 0, 0, 0, 1}, respectively. The corresponding output aggregate in threads 0, 4, 8, etc. will be 3, 7, 11, etc. (and is undefined in other threads).
- Template Parameters
ReductionOp – [inferred] Binary reduction operator type having member T operator()(const T &a, const T &b)
- Parameters
input – [in] Calling thread’s input
tail_flag – [in] Tail flag denoting whether or not input is the end of the current segment
reduction_op – [in] Reduction operator
-
struct TempStorage : public Uninitialized<_TempStorage>
The operations exposed by WarpReduce require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the __shared__ keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or union’d with other storage allocation types to facilitate memory reuse.
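As an illustration of the union pattern, a minimal sketch (hypothetical kernel and buffer names; pairs WarpReduce with cub::WarpScan, whose TempStorage can occupy the same bytes because the two collectives are used in sequence):

#include <cub/cub.cuh>

// Hypothetical kernel launched with one 32-thread warp
__global__ void UnionKernel(int *d_data, int *d_sum)
{
    using WarpReduce = cub::WarpReduce<int>;
    using WarpScan   = cub::WarpScan<int>;

    // The two collectives' storage shares the same shared-memory bytes
    __shared__ union
    {
        typename WarpReduce::TempStorage reduce;
        typename WarpScan::TempStorage   scan;
    } temp_storage;

    int thread_data = d_data[threadIdx.x];

    // Warp-wide sum (valid in lane0)
    int warp_sum = WarpReduce(temp_storage.reduce).Sum(thread_data);

    // Barrier before the same bytes are repurposed for the scan
    __syncwarp();

    // Warp-wide inclusive prefix sum
    WarpScan(temp_storage.scan).InclusiveSum(thread_data, thread_data);

    d_data[threadIdx.x] = thread_data;
    if (threadIdx.x == 0)
        d_sum[0] = warp_sum;
}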