cub::BlockDiscontinuity

Defined in cub/block/block_discontinuity.cuh

template<typename T, int BLOCK_DIM_X, int BLOCK_DIM_Y = 1, int BLOCK_DIM_Z = 1, int LEGACY_PTX_ARCH = 0>
class BlockDiscontinuity

The BlockDiscontinuity class provides collective methods for flagging discontinuities within an ordered set of items partitioned across a CUDA thread block.

Overview

  • A set of “head flags” (or “tail flags”) is often used to indicate corresponding items that differ from their predecessors (or successors). For example, head flags are convenient for demarcating disjoint data segments as part of a segmented scan or reduction.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.
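The following host-side C++ sketch makes the blocked arrangement concrete by computing where each item lands in the tile. It assumes the row-major linearization x + y * BLOCK_DIM_X + z * BLOCK_DIM_X * BLOCK_DIM_Y of the thread index. The helpers linear_thread_rank and tile_rank are hypothetical and not part of CUB.

```cpp
#include <cassert>

// Hypothetical helper (not part of CUB): row-major linear rank of a thread,
// with x varying fastest, then y, then z.
constexpr int linear_thread_rank(int x, int y, int z, int block_dim_x, int block_dim_y)
{
    return x + y * block_dim_x + z * block_dim_x * block_dim_y;
}

// Hypothetical helper (not part of CUB): in a blocked arrangement, the thread
// with linear rank t owns the t-th run of items_per_thread contiguous items,
// so its item `item` sits at this position within the tile.
constexpr int tile_rank(int linear_tid, int item, int items_per_thread)
{
    return linear_tid * items_per_thread + item;
}

// 128 threads x 4 items = 512 items, as in the snippets on this page.
static_assert(tile_rank(0, 0, 4) == 0, "thread 0 owns items 0..3");
static_assert(tile_rank(1, 0, 4) == 4, "thread 1 starts at item 4");
static_assert(tile_rank(127, 3, 4) == 511, "last item of the tile");

// A 32x4 block also has 128 threads; thread (x=5, y=2) has linear rank 69.
static_assert(linear_thread_rank(5, 2, 0, 32, 4) == 69, "row-major ordering");
```

Because flagging compares each item only with its neighbors in this tile order, a 2D or 3D block behaves exactly like a 1D block of the same total size.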

Performance Considerations

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • Incurs zero bank conflicts for most types.

A Simple Example

Every thread in the block uses the BlockDiscontinuity class by first specializing the BlockDiscontinuity type, then instantiating an instance with parameters for communication, and finally invoking one or more collective member functions.

The code snippet below illustrates the head flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Collectively compute head flags for discontinuities in the segment
    int head_flags[4];
    BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality());
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }. The corresponding output head_flags in those threads will be { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }.
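The head-flag semantics of this example can be checked with a small host-side reference model. This is a sketch for illustration, not CUB's implementation. It assumes the tile is flattened in thread order, so the last item of thread t is followed by the first item of thread t + 1, and it uses a hypothetical NotEqual functor as a stand-in for cub::Inequality. Only the first four threads of the example are modeled. Their flags depend only on earlier items, so truncating the tile does not change them.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference model of FlagHeads(head_flags, input, flag_op); a sketch,
// not CUB's implementation. `tile` is the whole blocked tile flattened in
// thread order, so each thread's last item is followed by the next thread's
// first item.
template <typename T, typename FlagOp>
std::vector<int> reference_flag_heads(const std::vector<T>& tile, FlagOp flag_op)
{
    std::vector<int> head_flags(tile.size());
    for (std::size_t i = 0; i < tile.size(); ++i)
    {
        // The first tile item has no predecessor and is always flagged.
        head_flags[i] = (i == 0) ? 1 : (flag_op(tile[i - 1], tile[i]) ? 1 : 0);
    }
    return head_flags;
}

// Hypothetical host stand-in for cub::Inequality.
struct NotEqual
{
    template <typename T>
    bool operator()(const T& a, const T& b) const { return a != b; }
};
```

Calling reference_flag_heads on {0,0,1,1, 1,1,1,1, 2,3,3,3, 3,4,4,4} with NotEqual() returns {1,0,1,0, 0,0,0,0, 1,1,0,0, 0,1,0,0}, which matches the head_flags listed for the first four threads.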

Re-using dynamically allocated shared memory

The examples/block/example_block_reduce_dyn_smem.cu example illustrates the use of dynamically allocated shared memory with BlockReduce and how to re-purpose the same memory region. This example can be easily adapted to the storage required by BlockDiscontinuity.

Template Parameters
  • T – The data type to be flagged.

  • BLOCK_DIM_X – The thread block length in threads along the X dimension

  • BLOCK_DIM_Y – [optional] The thread block length in threads along the Y dimension (default: 1)

  • BLOCK_DIM_Z – [optional] The thread block length in threads along the Z dimension (default: 1)

  • LEGACY_PTX_ARCH – [optional] Unused

Collective constructors

inline BlockDiscontinuity()

Collective constructor using a private static allocation of shared memory as temporary storage.

inline BlockDiscontinuity(TempStorage &temp_storage)

Collective constructor using the specified memory allocation as temporary storage.

Parameters

temp_storage – [in] Reference to memory allocation having layout type TempStorage

Head flag operations

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op)

Sets head flags indicating discontinuities between items partitioned across the thread block, for which the first item has no reference and is always flagged.

  • The flag head_flags[i] is set for item input[i] when flag_op(previous-item, input[i]) returns true (where previous-item is either the preceding item in the same thread or the last item in the previous thread).

  • For thread 0, item input[0] is always flagged.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the head-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Collectively compute head flags for discontinuities in the segment
    int head_flags[4];
    BlockDiscontinuity(temp_storage).FlagHeads(head_flags, thread_data, cub::Inequality());
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }. The corresponding output head_flags in those threads will be { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }.

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.
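The three-argument FlagOp form receives b_index, the rank of b in the tile. As a hedged illustration, the hypothetical RowAwareInequality functor also forces a discontinuity at every multiple of row_length, for example to keep fixed-length rows in separate segments. The reference model that calls it is a host-side sketch, not CUB's implementation, and assumes b_index equals the item's position in the flattened blocked tile.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical functor using the three-argument form: flags a discontinuity
// when the values differ OR when b is the first item of a new fixed-size row.
struct RowAwareInequality
{
    unsigned int row_length;

    bool operator()(const int& a, const int& b, unsigned int b_index) const
    {
        return (a != b) || (b_index % row_length == 0);
    }
};

// Host-side reference model (not CUB's implementation) that passes b_index,
// the rank of b in the flattened blocked tile, to the predicate.
template <typename T, typename FlagOp>
std::vector<int> reference_flag_heads_indexed(const std::vector<T>& tile, FlagOp flag_op)
{
    std::vector<int> head_flags(tile.size());
    for (std::size_t i = 0; i < tile.size(); ++i)
    {
        // Item 0 has no predecessor and is always flagged.
        head_flags[i] = (i == 0) ? 1
            : (flag_op(tile[i - 1], tile[i], static_cast<unsigned int>(i)) ? 1 : 0);
    }
    return head_flags;
}
```

With row_length = 3, a constant tile {7,7,7,7,7,7} gets head flags {1,0,0,1,0,0}. The values never change, but b_index 3 starts a new row.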

Parameters
  • head_flags – [out] Calling thread’s discontinuity head_flags

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op, T tile_predecessor_item)

Sets head flags indicating discontinuities between items partitioned across the thread block.

  • The flag head_flags[i] is set for item input[i] when flag_op(previous-item, input[i]) returns true (where previous-item is either the preceding item in the same thread or the last item in the previous thread).

  • For thread 0, item input[0] is compared against tile_predecessor_item.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the head-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Have thread 0 obtain the predecessor item for the entire tile
    int tile_predecessor_item;
    if (threadIdx.x == 0) tile_predecessor_item = ...

    // Collectively compute head flags for discontinuities in the segment
    int head_flags[4];
    BlockDiscontinuity(temp_storage).FlagHeads(
        head_flags, thread_data, cub::Inequality(), tile_predecessor_item);
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], [3,4,4,4], ... }, and that tile_predecessor_item is 0. The corresponding output head_flags in those threads will be { [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }.
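The following host-side sketch models the tile_predecessor_item overload. It is not CUB's implementation, and NotEqual is a hypothetical stand-in for cub::Inequality. It reproduces the flags above for the first four threads. It also shows a typical use, assuming a sequence is processed tile by tile: passing the last item of the previous tile as tile_predecessor_item gives the same flags as flagging the concatenated sequence.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference model (not CUB's implementation) of FlagHeads with a
// tile_predecessor_item: the first tile item is compared against the
// predecessor instead of being flagged unconditionally.
template <typename T, typename FlagOp>
std::vector<int> reference_flag_heads(const std::vector<T>& tile,
                                      const T& tile_predecessor_item,
                                      FlagOp flag_op)
{
    std::vector<int> head_flags(tile.size());
    for (std::size_t i = 0; i < tile.size(); ++i)
    {
        // Item 0 looks at the predecessor; every other item at its left neighbor.
        const T& prev = (i == 0) ? tile_predecessor_item : tile[i - 1];
        head_flags[i] = flag_op(prev, tile[i]) ? 1 : 0;
    }
    return head_flags;
}

// Hypothetical host stand-in for cub::Inequality.
struct NotEqual
{
    bool operator()(int a, int b) const { return a != b; }
};
```

For example, split {5,5,6,6,6,7} into the tiles {5,5,6} and {6,6,7}, then pass 6 as the predecessor of the second tile. This yields {0,0,1}, which matches positions 3 to 5 of the whole-sequence flags {1,0,1,0,0,1}.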

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.

Parameters
  • head_flags – [out] Calling thread’s discontinuity head_flags

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

  • tile_predecessor_item – [in] (thread 0 only) Item with which to compare the first tile item (input[0] from thread 0).

Tail flag operations

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op)

Sets tail flags indicating discontinuities between items partitioned across the thread block, for which the last item has no reference and is always flagged.

  • The flag tail_flags[i] is set for item input[i] when flag_op(input[i], next-item) returns true (where next-item is either the next item in the same thread or the first item in the next thread).

  • For thread BLOCK_THREADS - 1, item input[ITEMS_PER_THREAD - 1] is always flagged.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Collectively compute tail flags for discontinuities in the segment
    int tail_flags[4];
    BlockDiscontinuity(temp_storage).FlagTails(tail_flags, thread_data, cub::Inequality());
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }. The corresponding output tail_flags in those threads will be { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }.
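The following host-side reference model of FlagTails is a sketch, not CUB's implementation. Because the example elides the middle threads, the check assumes a hypothetical 4-thread tile { [0,0,1,1], [1,1,1,1], [2,3,3,3], [124,125,125,125] }. This makes thread 2's final flag concrete: 3 differs from 124, so the flag is 1. The other flags match the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference model of FlagTails(tail_flags, input, flag_op); a sketch,
// not CUB's implementation. `tile` is the blocked tile flattened in thread order.
template <typename T, typename FlagOp>
std::vector<int> reference_flag_tails(const std::vector<T>& tile, FlagOp flag_op)
{
    std::vector<int> tail_flags(tile.size());
    for (std::size_t i = 0; i < tile.size(); ++i)
    {
        // The last tile item has no successor and is always flagged.
        tail_flags[i] = (i + 1 == tile.size()) ? 1
            : (flag_op(tile[i], tile[i + 1]) ? 1 : 0);
    }
    return tail_flags;
}

// Hypothetical host stand-in for cub::Inequality.
struct NotEqual
{
    bool operator()(int a, int b) const { return a != b; }
};
```

For the assumed tile, reference_flag_tails returns {0,1,0,0, 0,0,0,1, 1,0,0,1, 1,0,0,1}.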

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.

Parameters
  • tail_flags – [out] Calling thread’s discontinuity tail_flags

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op, T tile_successor_item)

Sets tail flags indicating discontinuities between items partitioned across the thread block.

  • The flag tail_flags[i] is set for item input[i] when flag_op(input[i], next-item) returns true (where next-item is either the next item in the same thread or the first item in the next thread).

  • For thread BLOCK_THREADS - 1, item input[ITEMS_PER_THREAD - 1] is compared against tile_successor_item.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Have thread 127 obtain the successor item for the entire tile
    int tile_successor_item;
    if (threadIdx.x == 127) tile_successor_item = ...

    // Collectively compute tail flags for discontinuities in the segment
    int tail_flags[4];
    BlockDiscontinuity(temp_storage).FlagTails(
        tail_flags, thread_data, cub::Inequality(), tile_successor_item);
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] } and that tile_successor_item is 125. The corresponding output tail_flags in those threads will be { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }.
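The tile_successor_item overload differs from the previous one only at the end of the tile. The following host-side sketch (not CUB's implementation) assumes a hypothetical 4-thread tile { [0,0,1,1], [1,1,1,1], [2,3,3,3], [124,125,125,125] } in place of the elided example. NotEqual is a hypothetical stand-in for cub::Inequality.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference model (not CUB's implementation) of FlagTails with a
// tile_successor_item: the last tile item is compared against the successor
// instead of being flagged unconditionally.
template <typename T, typename FlagOp>
std::vector<int> reference_flag_tails(const std::vector<T>& tile,
                                      const T& tile_successor_item,
                                      FlagOp flag_op)
{
    std::vector<int> tail_flags(tile.size());
    for (std::size_t i = 0; i < tile.size(); ++i)
    {
        // The last item looks at the successor; every other item at its right neighbor.
        const T& next = (i + 1 == tile.size()) ? tile_successor_item : tile[i + 1];
        tail_flags[i] = flag_op(tile[i], next) ? 1 : 0;
    }
    return tail_flags;
}

// Hypothetical host stand-in for cub::Inequality.
struct NotEqual
{
    bool operator()(int a, int b) const { return a != b; }
};
```

With tile_successor_item = 125, the last thread's flags become [1,0,0,0], as in the example. Any successor other than 125 would set the final flag.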

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.

Parameters
  • tail_flags – [out] Calling thread’s discontinuity tail_flags

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

  • tile_successor_item – [in] (thread BLOCK_THREADS - 1 only) Item with which to compare the last tile item (input[ITEMS_PER_THREAD - 1] from thread BLOCK_THREADS - 1).

Head & tail flag operations

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], FlagT (&tail_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op)

Sets both head and tail flags indicating discontinuities between items partitioned across the thread block.

  • The flag head_flags[i] is set for item input[i] when flag_op(previous-item, input[i]) returns true (where previous-item is either the preceding item in the same thread or the last item in the previous thread).

  • For thread 0, item input[0] is always flagged.

  • The flag tail_flags[i] is set for item input[i] when flag_op(input[i], next-item) returns true (where next-item is either the next item in the same thread or the first item in the next thread).

  • For thread BLOCK_THREADS - 1, item input[ITEMS_PER_THREAD - 1] is always flagged.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Collectively compute head and tail flags for discontinuities in the segment
    int head_flags[4];
    int tail_flags[4];
    BlockDiscontinuity(temp_storage).FlagHeadsAndTails(
        head_flags, tail_flags, thread_data, cub::Inequality());
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }. The corresponding output head_flags in those threads will be { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }, and the corresponding output tail_flags in those threads will be { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }.
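Head and tail flags are often consumed together to locate segments. The following sketch is a host-side reference model (not CUB's implementation), plus a hypothetical segments_from_flags helper. The helper pairs each head flag with the next tail flag to recover inclusive [begin, end] ranges. Within a tile, the same comparison flag_op(input[i], input[i+1]) decides both tail_flags[i] and head_flags[i+1], so the two arrays are shifted copies of each other except at the tile ends.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Host-side reference model (not CUB's implementation) of FlagHeadsAndTails
// without tile_predecessor_item / tile_successor_item: the first item always
// gets a head flag and the last item always gets a tail flag.
template <typename T, typename FlagOp>
void reference_flag_heads_and_tails(const std::vector<T>& tile, FlagOp flag_op,
                                    std::vector<int>& head_flags,
                                    std::vector<int>& tail_flags)
{
    const std::size_t n = tile.size();
    head_flags.assign(n, 0);
    tail_flags.assign(n, 0);
    for (std::size_t i = 0; i < n; ++i)
    {
        head_flags[i] = (i == 0) ? 1 : (flag_op(tile[i - 1], tile[i]) ? 1 : 0);
        tail_flags[i] = (i + 1 == n) ? 1 : (flag_op(tile[i], tile[i + 1]) ? 1 : 0);
    }
}

// Hypothetical helper: pair each head flag with the next tail flag to obtain
// the inclusive [begin, end] range of every segment.
std::vector<std::pair<std::size_t, std::size_t>>
segments_from_flags(const std::vector<int>& head_flags, const std::vector<int>& tail_flags)
{
    std::vector<std::pair<std::size_t, std::size_t>> segments;
    std::size_t begin = 0;
    for (std::size_t i = 0; i < head_flags.size(); ++i)
    {
        if (head_flags[i]) begin = i;                          // segment starts here
        if (tail_flags[i]) segments.emplace_back(begin, i);    // segment ends here
    }
    return segments;
}

// Hypothetical host stand-in for cub::Inequality.
struct NotEqual
{
    bool operator()(int a, int b) const { return a != b; }
};
```

For the first three threads of the example, {0,0,1,1, 1,1,1,1, 2,3,3,3}, this yields the segments [0,1], [2,7], [8,8] and [9,11], one per run of equal values.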

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.

Parameters
  • head_flags – [out] Calling thread’s discontinuity head_flags

  • tail_flags – [out] Calling thread’s discontinuity tail_flags

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], FlagT (&tail_flags)[ITEMS_PER_THREAD], T tile_successor_item, T (&input)[ITEMS_PER_THREAD], FlagOp flag_op)

Sets both head and tail flags indicating discontinuities between items partitioned across the thread block.

  • The flag head_flags[i] is set for item input[i] when flag_op(previous-item, input[i]) returns true (where previous-item is either the preceding item in the same thread or the last item in the previous thread).

  • For thread 0, item input[0] is always flagged.

  • The flag tail_flags[i] is set for item input[i] when flag_op(input[i], next-item) returns true (where next-item is either the next item in the same thread or the first item in the next thread).

  • For thread BLOCK_THREADS - 1, item input[ITEMS_PER_THREAD - 1] is compared against tile_successor_item.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Have thread 127 obtain the successor item for the entire tile
    int tile_successor_item;
    if (threadIdx.x == 127) tile_successor_item = ...

    // Collectively compute head and tail flags for discontinuities in the segment
    int head_flags[4];
    int tail_flags[4];
    BlockDiscontinuity(temp_storage).FlagHeadsAndTails(
        head_flags, tail_flags, tile_successor_item, thread_data, cub::Inequality());
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] } and that the tile_successor_item is 125. The corresponding output head_flags in those threads will be { [1,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }, and the corresponding output tail_flags in those threads will be { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }.

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.

Parameters
  • head_flags – [out] Calling thread’s discontinuity head_flags

  • tail_flags – [out] Calling thread’s discontinuity tail_flags

  • tile_successor_item – [in] (thread BLOCK_THREADS - 1 only) Item with which to compare the last tile item (input[ITEMS_PER_THREAD - 1] from thread BLOCK_THREADS - 1).

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], T tile_predecessor_item, FlagT (&tail_flags)[ITEMS_PER_THREAD], T (&input)[ITEMS_PER_THREAD], FlagOp flag_op)

Sets both head and tail flags indicating discontinuities between items partitioned across the thread block.

  • The flag head_flags[i] is set for item input[i] when flag_op(previous-item, input[i]) returns true (where previous-item is either the preceding item in the same thread or the last item in the previous thread).

  • For thread 0, item input[0] is compared against tile_predecessor_item.

  • The flag tail_flags[i] is set for item input[i] when flag_op(input[i], next-item) returns true (where next-item is either the next item in the same thread or the first item in the next thread).

  • For thread BLOCK_THREADS - 1, item input[ITEMS_PER_THREAD - 1] is always flagged.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Have thread 0 obtain the predecessor item for the entire tile
    int tile_predecessor_item;
    if (threadIdx.x == 0) tile_predecessor_item = ...

    // Collectively compute head and tail flags for discontinuities in the segment
    int head_flags[4];
    int tail_flags[4];
    BlockDiscontinuity(temp_storage).FlagHeadsAndTails(
        head_flags, tile_predecessor_item, tail_flags,
        thread_data, cub::Inequality());
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] } and that the tile_predecessor_item is 0. The corresponding output head_flags in those threads will be { [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }, and the corresponding output tail_flags in those threads will be { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,1] }.

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.

Parameters
  • head_flags – [out] Calling thread’s discontinuity head_flags

  • tile_predecessor_item – [in] (thread 0 only) Item with which to compare the first tile item (input[0] from thread 0).

  • tail_flags – [out] Calling thread’s discontinuity tail_flags

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

template<int ITEMS_PER_THREAD, typename FlagT, typename FlagOp>
inline void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], T tile_predecessor_item, FlagT (&tail_flags)[ITEMS_PER_THREAD], T tile_successor_item, T (&input)[ITEMS_PER_THREAD], FlagOp flag_op)

Sets both head and tail flags indicating discontinuities between items partitioned across the thread block.

  • The flag head_flags[i] is set for item input[i] when flag_op(previous-item, input[i]) returns true (where previous-item is either the preceding item in the same thread or the last item in the previous thread).

  • For thread 0, item input[0] is compared against tile_predecessor_item.

  • The flag tail_flags[i] is set for item input[i] when flag_op(input[i], next-item) returns true (where next-item is either the next item in the same thread or the first item in the next thread).

  • For thread BLOCK_THREADS - 1, item input[ITEMS_PER_THREAD - 1] is compared against tile_successor_item.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • Efficiency increases with greater granularity (larger ITEMS_PER_THREAD). Performance typically keeps improving until the additional register pressure or shared memory allocation size causes SM occupancy to fall too low. Consider variants of cub::BlockLoad for efficiently gathering a blocked arrangement of elements across threads.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the head- and tail-flagging of 512 integer items that are partitioned in a blocked arrangement across 128 threads where each thread owns 4 consecutive items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_discontinuity.cuh>

__global__ void ExampleKernel(...)
{
    // Specialize BlockDiscontinuity for a 1D block of 128 threads of type int
    using BlockDiscontinuity = cub::BlockDiscontinuity<int, 128>;

    // Allocate shared memory for BlockDiscontinuity
    __shared__ typename BlockDiscontinuity::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Have thread 0 obtain the predecessor item for the entire tile
    int tile_predecessor_item;
    if (threadIdx.x == 0) tile_predecessor_item = ...

    // Have thread 127 obtain the successor item for the entire tile
    int tile_successor_item;
    if (threadIdx.x == 127) tile_successor_item = ...

    // Collectively compute head and tail flags for discontinuities in the segment
    int head_flags[4];
    int tail_flags[4];
    BlockDiscontinuity(temp_storage).FlagHeadsAndTails(
        head_flags, tile_predecessor_item, tail_flags, tile_successor_item,
        thread_data, cub::Inequality());
}

Suppose the set of input thread_data across the block of threads is { [0,0,1,1], [1,1,1,1], [2,3,3,3], ..., [124,125,125,125] }, that the tile_predecessor_item is 0, and that the tile_successor_item is 125. The corresponding output head_flags in those threads will be { [0,0,1,0], [0,0,0,0], [1,1,0,0], [0,1,0,0], ... }, and the corresponding output tail_flags in those threads will be { [0,1,0,0], [0,0,0,1], [1,0,0,...], ..., [1,0,0,0] }.
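The four FlagHeadsAndTails overloads differ only in how the ends of the tile are treated. A compact host-side reference model captures all of them. It is a sketch, not CUB's implementation, and it requires C++17 for std::optional. An empty optional reproduces the always-flagged behavior, and a present value is compared like tile_predecessor_item or tile_successor_item. The check assumes a hypothetical 4-thread tile { [0,0,1,1], [1,1,1,1], [2,3,3,3], [124,125,125,125] }, so the last thread's head flags come from this assumed data rather than from the elided example.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Host-side reference model covering the four FlagHeadsAndTails overloads
// (a sketch, not CUB's implementation). An empty optional reproduces the
// "always flagged" behavior at that end of the tile; a value is compared like
// tile_predecessor_item or tile_successor_item.
template <typename T, typename FlagOp>
void reference_flag_heads_and_tails(const std::vector<T>& tile,
                                    const std::optional<T>& tile_predecessor_item,
                                    const std::optional<T>& tile_successor_item,
                                    FlagOp flag_op,
                                    std::vector<int>& head_flags,
                                    std::vector<int>& tail_flags)
{
    const std::size_t n = tile.size();
    head_flags.assign(n, 0);
    tail_flags.assign(n, 0);
    for (std::size_t i = 0; i < n; ++i)
    {
        // Head flag: compare with the previous item, or the predecessor for item 0.
        if (i > 0)
            head_flags[i] = flag_op(tile[i - 1], tile[i]) ? 1 : 0;
        else
            head_flags[i] = tile_predecessor_item
                ? (flag_op(*tile_predecessor_item, tile[i]) ? 1 : 0) : 1;

        // Tail flag: compare with the next item, or the successor for the last item.
        if (i + 1 < n)
            tail_flags[i] = flag_op(tile[i], tile[i + 1]) ? 1 : 0;
        else
            tail_flags[i] = tile_successor_item
                ? (flag_op(tile[i], *tile_successor_item) ? 1 : 0) : 1;
    }
}

// Hypothetical host stand-in for cub::Inequality.
struct NotEqual
{
    bool operator()(int a, int b) const { return a != b; }
};
```

Passing std::optional<int>(0) and std::optional<int>(125) corresponds to the overload documented here. Passing std::optional<int>() for either end models the overloads that omit that item.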

Template Parameters
  • ITEMS_PER_THREAD – [inferred] The number of consecutive items partitioned onto each thread.

  • FlagT – [inferred] The flag type (must be an integer type).

  • FlagOp – [inferred] Binary predicate functor type having member bool operator()(const T &a, const T &b) or member bool operator()(const T &a, const T &b, unsigned int b_index), and returning true if a discontinuity exists between a and b, otherwise false. b_index is the rank of b in the aggregate tile of data.

Parameters
  • head_flags – [out] Calling thread’s discontinuity head_flags

  • tile_predecessor_item – [in] (thread 0 only) Item with which to compare the first tile item (input[0] from thread 0).

  • tail_flags – [out] Calling thread’s discontinuity tail_flags

  • tile_successor_item – [in] (thread BLOCK_THREADS - 1 only) Item with which to compare the last tile item (input[ITEMS_PER_THREAD - 1] from thread BLOCK_THREADS - 1).

  • input – [in] Calling thread’s input items

  • flag_op – [in] Binary boolean flag predicate

struct TempStorage : public Uninitialized<_TempStorage>

The operations exposed by BlockDiscontinuity require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the __shared__ keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or union’d with other storage allocation types to facilitate memory reuse.