cub::BlockExchange

Defined in cub/block/block_exchange.cuh

template<typename T, int BLOCK_DIM_X, int ITEMS_PER_THREAD, bool WARP_TIME_SLICING = false, int BLOCK_DIM_Y = 1, int BLOCK_DIM_Z = 1, int LEGACY_PTX_ARCH = 0>
class BlockExchange

The BlockExchange class provides collective methods for rearranging data partitioned across a CUDA thread block.

Overview

  • It is commonplace for blocks of threads to rearrange data items between threads. For example, the device-accessible memory subsystem prefers access patterns where data items are “striped” across threads (where consecutive threads access consecutive items), yet most block-wide operations prefer a “blocked” partitioning of items across threads (where consecutive items belong to a single thread).

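  • BlockExchange supports the following types of data exchanges:

      • Transposing between blocked and striped arrangements
      • Transposing between blocked and warp-striped arrangements
      • Scattering ranked items to a blocked arrangement
      • Scattering ranked items to a striped arrangement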

  • For multi-dimensional blocks, threads are linearly ranked in row-major order.
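
The arrangements described above can be summarized as index formulas. The host-side C++ sketch below is illustrative only: blocked_index, striped_index, and linear_rank are hypothetical helper names, not part of CUB.

```cpp
// Tile index of the item held in slot `i` of thread `tid` under a blocked
// arrangement: consecutive items belong to a single thread.
constexpr int blocked_index(int tid, int i, int items_per_thread)
{
    return tid * items_per_thread + i;
}

// Tile index of the item held in slot `i` of thread `tid` under a striped
// arrangement: consecutive threads hold consecutive items.
constexpr int striped_index(int tid, int i, int block_threads)
{
    return i * block_threads + tid;
}

// Row-major linear rank of thread (x, y, z) in a block of extent (dim_x, dim_y, *).
constexpr int linear_rank(int x, int y, int z, int dim_x, int dim_y)
{
    return (z * dim_y + y) * dim_x + x;
}
```

For a block of 128 threads with 4 items each, thread 1 holds items 4 to 7 when blocked and items 1, 129, 257, 385 when striped.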

A Simple Example

Every thread in the block uses the BlockExchange class by first specializing the BlockExchange type, then instantiating an instance with parameters for communication, and finally invoking one or more collective member functions.

The code snippet below illustrates the conversion from a “striped” to a “blocked” arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_exchange.cuh>

__global__ void ExampleKernel(int *d_data, ...)
{
    // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
    using BlockExchange = cub::BlockExchange<int, 128, 4>;

    // Allocate shared memory for BlockExchange
    __shared__ typename BlockExchange::TempStorage temp_storage;

    // Load a tile of data striped across threads
    int thread_data[4];
    cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data);

    // Collectively exchange data into a blocked arrangement across threads
    BlockExchange(temp_storage).StripedToBlocked(thread_data, thread_data);
}

Suppose the set of striped input thread_data across the block of threads is { [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] }. The corresponding output thread_data in those threads will be { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }.

Performance Considerations

  • Proper device-specific padding ensures zero bank conflicts for most types.
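
To see why padding can help, consider the host-side sketch below. It assumes 32 shared memory banks and a padding rule that skips one slot after every 32 slots. Both are assumptions made for illustration; the padding BlockExchange actually applies is device-specific. The names padded_offset and distinct_banks are hypothetical.

```cpp
#include <set>

constexpr int kBanks = 32;  // assumed number of shared memory banks

// Assumed padding rule: leave one unused slot after every kBanks slots.
constexpr int padded_offset(int offset)
{
    return offset + offset / kBanks;
}

// Counts the distinct banks touched when the 32 threads of one warp each
// write the first item of a blocked tile (thread t writes slot t * items_per_thread).
inline int distinct_banks(int items_per_thread, bool use_padding)
{
    std::set<int> banks;
    for (int t = 0; t < 32; ++t)
    {
        int offset = t * items_per_thread;
        if (use_padding)
        {
            offset = padded_offset(offset);
        }
        banks.insert(offset % kBanks);
    }
    return static_cast<int>(banks.size());
}
```

Under these assumptions, with 8 items per thread the unpadded writes fall into only 4 banks, while the padded writes fall into 32 distinct banks.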

Re-using dynamically allocated shared memory

The block/example_block_reduce_dyn_smem.cu example illustrates the use of dynamically allocated shared memory with BlockReduce and how to repurpose the same memory region. The example can be adapted to the storage required by BlockExchange.

Template Parameters
  • T – The data type to be exchanged

  • BLOCK_DIM_X – The thread block length in threads along the X dimension

  • ITEMS_PER_THREAD – The number of items partitioned onto each thread.

  • WARP_TIME_SLICING – [optional] When true, only use enough shared memory for a single warp’s worth of tile data, time-slicing the block-wide exchange over multiple synchronized rounds. Yields a smaller memory footprint at the expense of decreased parallelism. (Default: false)

  • BLOCK_DIM_Y – [optional] The thread block length in threads along the Y dimension (default: 1)

  • BLOCK_DIM_Z – [optional] The thread block length in threads along the Z dimension (default: 1)

  • LEGACY_PTX_ARCH – [optional] Unused.
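
The WARP_TIME_SLICING trade-off can be estimated with simple arithmetic. The sketch below counts staged items only and ignores any padding, so it is a rough estimate rather than the size of TempStorage. It also assumes 32-thread warps and one exchange round per warp. staged_items and exchange_rounds are hypothetical names.

```cpp
// Items held in shared memory at once: a full tile, or one warp's share of it.
constexpr int staged_items(int block_threads, int items_per_thread,
                           bool warp_time_slicing, int warp_threads = 32)
{
    return (warp_time_slicing ? warp_threads : block_threads) * items_per_thread;
}

// Assumed number of synchronized rounds: one per warp when time-slicing.
constexpr int exchange_rounds(int block_threads, bool warp_time_slicing,
                              int warp_threads = 32)
{
    return warp_time_slicing ? (block_threads + warp_threads - 1) / warp_threads : 1;
}
```

For 128 threads owning 4 items each, this gives 512 staged items in one round without time-slicing, and 128 staged items over 4 rounds with it.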

Collective constructors

inline BlockExchange()

Collective constructor using a private static allocation of shared memory as temporary storage.

inline BlockExchange(TempStorage &temp_storage)

Collective constructor using the specified memory allocation as temporary storage.

Parameters

temp_storage[in] Reference to memory allocation having layout type TempStorage

Structured exchanges

template<typename OutputT>
inline void StripedToBlocked(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])

Transposes data items from striped arrangement to blocked arrangement.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the conversion from a “striped” to a “blocked” arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_exchange.cuh>

__global__ void ExampleKernel(int *d_data, ...)
{
    // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
    using BlockExchange = cub::BlockExchange<int, 128, 4>;

    // Allocate shared memory for BlockExchange
    __shared__ typename BlockExchange::TempStorage temp_storage;

    // Load a tile of ordered data into a striped arrangement across block threads
    int thread_data[4];
    cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data);

    // Collectively exchange data into a blocked arrangement across threads
    BlockExchange(temp_storage).StripedToBlocked(thread_data, thread_data);
}

Suppose the set of striped input thread_data across the block of threads is { [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] } after loading from device-accessible memory. The corresponding output thread_data in those threads will be { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }.
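
The input and output shown above can be reproduced on the host. The sketch below models the exchange as a write to a staging buffer followed by a read. It is a reference model for checking results, not CUB's implementation. striped_to_blocked and make_striped_input are hypothetical names.

```cpp
#include <array>
#include <vector>

constexpr int kBlockThreads   = 128;
constexpr int kItemsPerThread = 4;

using ThreadItems = std::array<int, kItemsPerThread>;
using Tile        = std::vector<ThreadItems>;  // one entry per thread

// Builds the striped input from the example: thread tid holds
// tid, tid + 128, tid + 256, tid + 384.
inline Tile make_striped_input()
{
    Tile tile(kBlockThreads);
    for (int tid = 0; tid < kBlockThreads; ++tid)
        for (int i = 0; i < kItemsPerThread; ++i)
            tile[tid][i] = i * kBlockThreads + tid;
    return tile;
}

// Host-side model of a striped-to-blocked exchange.
inline Tile striped_to_blocked(const Tile& striped)
{
    std::vector<int> staging(kBlockThreads * kItemsPerThread);

    // Phase 1: every thread writes slot i to staging position i * 128 + tid.
    for (int tid = 0; tid < kBlockThreads; ++tid)
        for (int i = 0; i < kItemsPerThread; ++i)
            staging[i * kBlockThreads + tid] = striped[tid][i];

    // On the device, a block-wide barrier would separate the two phases.

    // Phase 2: every thread reads its 4 consecutive staging positions.
    Tile blocked(kBlockThreads);
    for (int tid = 0; tid < kBlockThreads; ++tid)
        for (int i = 0; i < kItemsPerThread; ++i)
            blocked[tid][i] = staging[tid * kItemsPerThread + i];
    return blocked;
}
```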

Parameters
  • input_items[in] Items to exchange, converting between striped and blocked arrangements.

  • output_items[out] Items from exchange, converting between striped and blocked arrangements.

template<typename OutputT>
inline void BlockedToStriped(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])

Transposes data items from blocked arrangement to striped arrangement.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the conversion from a “blocked” to a “striped” arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_exchange.cuh>

__global__ void ExampleKernel(int *d_data, ...)
{
    // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
    using BlockExchange = cub::BlockExchange<int, 128, 4>;

    // Allocate shared memory for BlockExchange
    __shared__ typename BlockExchange::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Collectively exchange data into a striped arrangement across threads
    BlockExchange(temp_storage).BlockedToStriped(thread_data, thread_data);

    // Store data striped across block threads into an ordered tile
    cub::StoreDirectStriped<128>(threadIdx.x, d_data, thread_data);
}

Suppose the set of blocked input thread_data across the block of threads is { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. The corresponding output thread_data in those threads will be { [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] } in preparation for storing to device-accessible memory.
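
Viewed per item, the exchange sends each blocked slot to a specific striped slot. The host-side sketch below computes that destination. Slot and blocked_to_striped_dest are hypothetical names used only for illustration.

```cpp
// A (thread, slot) position within the tile.
struct Slot
{
    int tid;
    int slot;
};

// Destination of the item held in slot `i` of thread `tid` (blocked input)
// once the tile is rearranged into a striped arrangement.
inline Slot blocked_to_striped_dest(int tid, int i, int block_threads, int items_per_thread)
{
    const int item = tid * items_per_thread + i;  // position of the item in the tile
    return Slot{item % block_threads, item / block_threads};
}
```

With 128 threads and 4 items per thread, the item in slot 1 of thread 0 (item 1) moves to slot 0 of thread 1, and the item in slot 0 of thread 32 (item 128) moves to slot 1 of thread 0.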

Parameters
  • input_items[in] Items to exchange, converting between striped and blocked arrangements.

  • output_items[out] Items from exchange, converting between striped and blocked arrangements.

template<typename OutputT>
inline void WarpStripedToBlocked(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])

Transposes data items from warp-striped arrangement to blocked arrangement.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the conversion from a “warp-striped” to a “blocked” arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_exchange.cuh>

__global__ void ExampleKernel(int *d_data, ...)
{
    // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
    using BlockExchange = cub::BlockExchange<int, 128, 4>;

    // Allocate shared memory for BlockExchange
    __shared__ typename BlockExchange::TempStorage temp_storage;

    // Load a tile of ordered data into a warp-striped arrangement across warp threads
    int thread_data[4];
    cub::LoadDirectWarpStriped(threadIdx.x, d_data, thread_data);

    // Collectively exchange data into a blocked arrangement across threads
    BlockExchange(temp_storage).WarpStripedToBlocked(thread_data, thread_data);
}

Suppose the set of warp-striped input thread_data across the block of threads is { [0,32,64,96], [1,33,65,97], [2,34,66,98], ..., [415,447,479,511] } after loading from device-accessible memory. (The first 128 items are striped across the first warp of 32 threads, the second 128 items are striped across the second warp, etc.) The corresponding output thread_data in those threads will be { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }.
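
The warp-striped arrangement described above can be written as an index formula. The host-side sketch below assumes 32-thread warps. warp_striped_index is a hypothetical helper name.

```cpp
// Tile index of the item held in slot `i` of thread `tid` under a
// warp-striped arrangement: each warp owns a contiguous segment of
// warp_threads * items_per_thread items, striped across its lanes.
inline int warp_striped_index(int tid, int i, int items_per_thread, int warp_threads = 32)
{
    const int warp = tid / warp_threads;
    const int lane = tid % warp_threads;
    return warp * warp_threads * items_per_thread + i * warp_threads + lane;
}
```

For 4 items per thread, thread 0 holds items 0, 32, 64, 96, thread 32 starts at item 128, and thread 127 holds items 415, 447, 479, 511.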

Parameters
  • input_items[in] Items to exchange, converting between warp-striped and blocked arrangements.

  • output_items[out] Items from exchange, converting between warp-striped and blocked arrangements.

template<typename OutputT>
inline void BlockedToWarpStriped(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])

Transposes data items from blocked arrangement to warp-striped arrangement.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the conversion from a “blocked” to a “warp-striped” arrangement of 512 integer items partitioned across 128 threads where each thread owns 4 items.

#include <cub/cub.cuh>   // or equivalently <cub/block/block_exchange.cuh>

__global__ void ExampleKernel(int *d_data, ...)
{
    // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
    using BlockExchange = cub::BlockExchange<int, 128, 4>;

    // Allocate shared memory for BlockExchange
    __shared__ typename BlockExchange::TempStorage temp_storage;

    // Obtain a segment of consecutive items that are blocked across threads
    int thread_data[4];
    ...

    // Collectively exchange data into a warp-striped arrangement across threads
    BlockExchange(temp_storage).BlockedToWarpStriped(thread_data, thread_data);

    // Store data striped across warp threads into an ordered tile
    cub::StoreDirectWarpStriped(threadIdx.x, d_data, thread_data);
}

Suppose the set of blocked input thread_data across the block of threads is { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. The corresponding output thread_data in those threads will be { [0,32,64,96], [1,33,65,97], [2,34,66,98], ..., [415,447,479,511] } in preparation for storing to device-accessible memory. (The first 128 items are striped across the first warp of 32 threads, the second 128 items are striped across the second warp, etc.)
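
The output listed above can be checked with a host-side reference model. The sketch below flattens thread registers so that slot i of thread tid sits at index tid * items_per_thread + i, and assumes 32-thread warps. It is not CUB's implementation. iota_tile and blocked_to_warp_striped are hypothetical names.

```cpp
#include <numeric>
#include <vector>

// Blocked input tile holding the values 0, 1, ..., n - 1.
inline std::vector<int> iota_tile(int n)
{
    std::vector<int> tile(n);
    std::iota(tile.begin(), tile.end(), 0);
    return tile;
}

// Host-side model of a blocked-to-warp-striped exchange on flattened registers.
inline std::vector<int> blocked_to_warp_striped(const std::vector<int>& blocked,
                                                int items_per_thread,
                                                int warp_threads = 32)
{
    const int warp_items = warp_threads * items_per_thread;
    std::vector<int> out(blocked.size());
    for (int p = 0; p < static_cast<int>(blocked.size()); ++p)
    {
        const int warp   = p / warp_items;         // warp whose segment contains item p
        const int within = p % warp_items;         // position inside that segment
        const int lane   = within % warp_threads;  // destination lane
        const int slot   = within / warp_threads;  // destination slot
        const int tid    = warp * warp_threads + lane;
        out[tid * items_per_thread + slot] = blocked[p];
    }
    return out;
}
```

Applied to the 512-item tile with 4 items per thread, the registers of thread 1 become 1, 33, 65, 97 and those of thread 127 become 415, 447, 479, 511.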

Parameters
  • input_items[in] Items to exchange, converting between blocked and warp-striped arrangements.

  • output_items[out] Items from exchange, converting between blocked and warp-striped arrangements.

Scatter exchanges

template<typename OutputT, typename OffsetT>
inline void ScatterToBlocked(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD])

Exchanges data items annotated by rank into blocked arrangement.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Template Parameters

OffsetT[inferred] Signed integer type for local offsets

Parameters
  • input_items[in] Items to exchange, each annotated by the corresponding entry of ranks.

  • output_items[out] Items from exchange, in blocked arrangement.

  • ranks[in] Corresponding scatter ranks
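
As a rough host-side illustration of the scatter semantics, the sketch below flattens thread registers so that slot i of thread tid sits at index tid * items_per_thread + i. It assumes that every rank is a distinct position within the tile. scatter_to_blocked is a hypothetical model, not the CUB member function.

```cpp
#include <cstddef>
#include <vector>

// Host-side model: the item with rank r ends up in thread r / items_per_thread,
// slot r % items_per_thread, which is flat index r in this register layout.
inline std::vector<int> scatter_to_blocked(const std::vector<int>& items,
                                           const std::vector<int>& ranks)
{
    std::vector<int> out(items.size());
    for (std::size_t k = 0; k < items.size(); ++k)
    {
        out[ranks[k]] = items[k];
    }
    return out;
}
```

For two threads with two items each, items {10, 20, 30, 40} with ranks {3, 0, 2, 1} yield {20, 40} in thread 0 and {30, 10} in thread 1.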

template<typename OutputT, typename OffsetT>
inline void ScatterToStriped(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD])

Exchanges data items annotated by rank into striped arrangement.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Template Parameters

OffsetT[inferred] Signed integer type for local offsets

Parameters
  • input_items[in] Items to exchange, each annotated by the corresponding entry of ranks.

  • output_items[out] Items from exchange, in striped arrangement.

  • ranks[in] Corresponding scatter ranks
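
A host-side sketch of the striped variant follows, with thread registers flattened so that slot i of thread tid sits at index tid * items_per_thread + i. It assumes that every rank is a distinct position within the tile. scatter_to_striped is a hypothetical model, not the CUB member function.

```cpp
#include <cstddef>
#include <vector>

// Host-side model: the item with rank r ends up in thread r % block_threads,
// slot r / block_threads.
inline std::vector<int> scatter_to_striped(const std::vector<int>& items,
                                           const std::vector<int>& ranks,
                                           int block_threads)
{
    const int items_per_thread = static_cast<int>(items.size()) / block_threads;
    std::vector<int> out(items.size());
    for (std::size_t k = 0; k < items.size(); ++k)
    {
        const int r = ranks[k];
        out[(r % block_threads) * items_per_thread + r / block_threads] = items[k];
    }
    return out;
}
```

For two threads with two items each, items {10, 20, 30, 40} with ranks {3, 0, 2, 1} yield {20, 30} in thread 0 (ranks 0 and 2) and {40, 10} in thread 1 (ranks 1 and 3).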

template<typename OutputT, typename OffsetT>
inline void ScatterToStripedGuarded(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD])

Exchanges data items annotated by rank into striped arrangement. Items with rank -1 are not exchanged.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Template Parameters

OffsetT[inferred] Signed integer type for local offsets

Parameters
  • input_items[in] Items to exchange, each annotated by the corresponding entry of ranks. Items with rank -1 are skipped.

  • output_items[out] Items from exchange, in striped arrangement.

  • ranks[in] Corresponding scatter ranks
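
The guarded variant can be modeled on the host as below. The fill parameter is an assumption of this sketch: it marks output slots that received no item, whose contents should not be relied upon after the real exchange. scatter_to_striped_guarded is a hypothetical model, not the CUB member function.

```cpp
#include <cstddef>
#include <vector>

// Host-side model on flattened registers (slot i of thread tid at index
// tid * items_per_thread + i). Items with a negative rank are skipped;
// the source documents -1 as the skip marker.
inline std::vector<int> scatter_to_striped_guarded(const std::vector<int>& items,
                                                   const std::vector<int>& ranks,
                                                   int block_threads,
                                                   int fill)
{
    const int items_per_thread = static_cast<int>(items.size()) / block_threads;
    std::vector<int> out(items.size(), fill);
    for (std::size_t k = 0; k < items.size(); ++k)
    {
        const int r = ranks[k];
        if (r < 0)
        {
            continue;  // not exchanged
        }
        out[(r % block_threads) * items_per_thread + r / block_threads] = items[k];
    }
    return out;
}
```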

template<typename OutputT, typename OffsetT, typename ValidFlag>
inline void ScatterToStripedFlagged(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD], ValidFlag (&is_valid)[ITEMS_PER_THREAD])

Exchanges valid data items annotated by rank into striped arrangement.

  • A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Template Parameters
  • OffsetT[inferred] Signed integer type for local offsets

  • ValidFlag[inferred] Flag type denoting which items are valid

Parameters
  • input_items[in] Items to exchange, each annotated by the corresponding entry of ranks. Only items flagged as valid in is_valid are exchanged.

  • output_items[out] Items from exchange, in striped arrangement.

  • ranks[in] Corresponding scatter ranks

  • is_valid[in] Corresponding flag denoting item validity
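
In the flagged variant a separate flag array, rather than a sentinel rank, selects the items to scatter. The host-side sketch below models this. The fill value is an assumption that marks slots which received no item. scatter_to_striped_flagged is a hypothetical model, not the CUB member function.

```cpp
#include <cstddef>
#include <vector>

// Host-side model on flattened registers (slot i of thread tid at index
// tid * items_per_thread + i). The rank of an invalid item is never read.
inline std::vector<int> scatter_to_striped_flagged(const std::vector<int>& items,
                                                   const std::vector<int>& ranks,
                                                   const std::vector<bool>& is_valid,
                                                   int block_threads,
                                                   int fill)
{
    const int items_per_thread = static_cast<int>(items.size()) / block_threads;
    std::vector<int> out(items.size(), fill);
    for (std::size_t k = 0; k < items.size(); ++k)
    {
        if (!is_valid[k])
        {
            continue;  // invalid items are not exchanged
        }
        const int r = ranks[k];
        out[(r % block_threads) * items_per_thread + r / block_threads] = items[k];
    }
    return out;
}
```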

Public Types

using TempStorage = Uninitialized<_TempStorage>

The operations exposed by BlockExchange require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the __shared__ keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or union’d with other storage allocation types to facilitate memory reuse.