cub::BlockLoad

Defined in cub/block/block_load.cuh

template<typename T, int BLOCK_DIM_X, int ITEMS_PER_THREAD, BlockLoadAlgorithm ALGORITHM = BLOCK_LOAD_DIRECT, int BLOCK_DIM_Y = 1, int BLOCK_DIM_Z = 1, int LEGACY_PTX_ARCH = 0>
class BlockLoad

The BlockLoad class provides collective data movement methods for loading a linear segment of items from memory into a blocked arrangement across a CUDA thread block.

Overview

A Simple Example

Every thread in the block uses the BlockLoad class by first specializing the BlockLoad type, then instantiating an instance with parameters for communication, and finally invoking one or more collective member functions.

The code snippet below illustrates the loading of a linear segment of 512 integers into a “blocked” arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for BLOCK_LOAD_WARP_TRANSPOSE, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads).

#include <cub/cub.cuh>   // or equivalently <cub/block/block_load.cuh>

__global__ void ExampleKernel(int *d_data /*, ... */)
{
    // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each
    using BlockLoad = cub::BlockLoad<int, 128, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE>;

    // Allocate shared memory for BlockLoad
    __shared__ typename BlockLoad::TempStorage temp_storage;

    // Load a segment of consecutive items that are blocked across threads
    int thread_data[4];
    BlockLoad(temp_storage).Load(d_data, thread_data);
}

Suppose the input d_data is 0, 1, 2, 3, 4, 5, .... The set of thread_data across the block of threads will then be { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }.
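To make the blocked arrangement concrete, the following host-side C++ sketch reproduces which items each thread ends up owning in the example above. It models only the result, not CUB's implementation; BlockedLoadModel, kBlockThreads, and kItemsPerThread are hypothetical names chosen for illustration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model (not CUB code) of the "blocked" arrangement produced by
// BlockLoad::Load: thread t receives items [t*ITEMS_PER_THREAD, (t+1)*ITEMS_PER_THREAD).
constexpr int kBlockThreads   = 128;  // BLOCK_DIM_X in the example
constexpr int kItemsPerThread = 4;    // ITEMS_PER_THREAD in the example

using ThreadItems = std::array<int, kItemsPerThread>;

std::vector<ThreadItems> BlockedLoadModel(const std::vector<int>& src)
{
    // The whole tile (128 * 4 = 512 items) must be available.
    assert(src.size() >= static_cast<std::size_t>(kBlockThreads * kItemsPerThread));

    std::vector<ThreadItems> thread_data(kBlockThreads);
    for (int t = 0; t < kBlockThreads; ++t)          // each simulated thread
        for (int i = 0; i < kItemsPerThread; ++i)    // each of its items
            thread_data[t][i] = src[t * kItemsPerThread + i];  // t-th contiguous run
    return thread_data;
}
```

Feeding it 0, 1, ..., 511 gives thread 0 the run [0,1,2,3] and thread 127 the run [508,509,510,511], matching the description above.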

Re-using dynamically allocated shared memory

The block/example_block_reduce_dyn_smem.cu example illustrates the use of dynamically allocated shared memory with BlockReduce and how to repurpose the same memory region. It can easily be adapted to the storage required by BlockLoad.

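Template Parameters
  • T: The data type to read into (which must be convertible from the input iterator’s value type)

  • BLOCK_DIM_X: The thread block length in threads along the X dimension

  • ITEMS_PER_THREAD: The number of consecutive items partitioned onto each thread

  • ALGORITHM: [optional] cub::BlockLoadAlgorithm tuning policy (default: cub::BLOCK_LOAD_DIRECT)

  • BLOCK_DIM_Y: [optional] The thread block length in threads along the Y dimension (default: 1)

  • BLOCK_DIM_Z: [optional] The thread block length in threads along the Z dimension (default: 1)

  • LEGACY_PTX_ARCH: [optional] Unused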

Collective constructors

inline BlockLoad()

Collective constructor using a private static allocation of shared memory as temporary storage.

inline BlockLoad(TempStorage &temp_storage)

Collective constructor using the specified memory allocation as temporary storage.

Parameters
  • temp_storage[in] Reference to memory allocation having layout type TempStorage

Data movement

template<typename RandomAccessIterator>
inline void Load(RandomAccessIterator block_src_it, T (&dst_items)[ITEMS_PER_THREAD])

Load a linear segment of items from memory.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • A subsequent __syncthreads() thread block barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.
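The ordering in the first bullet can be sketched on the host as follows. This is a hedged model that assumes "row-major" means threadIdx.x varies fastest, then threadIdx.y, then threadIdx.z; LinearThreadRank and FirstOwnedItem are hypothetical helper names, not part of CUB.

```cpp
#include <cassert>

// Host-side model (not CUB code) of the row-major thread ordering assumed for
// multi-dimensional thread blocks: x varies fastest, then y, then z.
inline int LinearThreadRank(int x, int y, int z, int block_dim_x, int block_dim_y)
{
    assert(0 <= x && x < block_dim_x);  // x must lie inside the block
    assert(0 <= y && y < block_dim_y);  // y must lie inside the block
    assert(0 <= z);
    return x + y * block_dim_x + z * block_dim_x * block_dim_y;
}

// In a blocked arrangement, the thread with linear rank r owns the contiguous
// items [r * items_per_thread, (r + 1) * items_per_thread).
inline int FirstOwnedItem(int linear_rank, int items_per_thread)
{
    return linear_rank * items_per_thread;
}
```

For a 32x4x2 block with 4 items per thread, thread (5, 2, 0) has rank 5 + 2 * 32 = 69 and therefore owns items 276 through 279.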

Snippet

The code snippet below illustrates the loading of a linear segment of 512 integers into a “blocked” arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for BLOCK_LOAD_WARP_TRANSPOSE, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads).

#include <cub/cub.cuh>   // or equivalently <cub/block/block_load.cuh>

__global__ void ExampleKernel(int *d_data /*, ... */)
{
    // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each
    using BlockLoad = cub::BlockLoad<int, 128, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE>;

    // Allocate shared memory for BlockLoad
    __shared__ typename BlockLoad::TempStorage temp_storage;

    // Load a segment of consecutive items that are blocked across threads
    int thread_data[4];
    BlockLoad(temp_storage).Load(d_data, thread_data);
}

Suppose the input d_data is 0, 1, 2, 3, 4, 5, .... The set of thread_data across the block of threads will then be { [0,1,2,3], [4,5,6,7], ..., [508,509,510,511] }.
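The warp-striped read followed by a local reordering can be modeled on the host to show why BLOCK_LOAD_WARP_TRANSPOSE yields the same blocked result while each memory step touches a contiguous range. This is a sketch of the idea only, not CUB's implementation. It assumes 32-thread warps, a thread count that is a multiple of the warp size, and a plain vector standing in for per-warp shared memory; StripedReadIndex and WarpTransposeLoadModel are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kWarpThreads = 32;  // assumption: 32-lane warps

// Global index read by `lane` of `warp` on load step `i` in the warp-striped
// pattern: on each step, lanes 0..31 touch 32 consecutive items (coalesced).
inline int StripedReadIndex(int warp, int lane, int i, int items_per_thread)
{
    return warp * kWarpThreads * items_per_thread + i * kWarpThreads + lane;
}

// Returns out[t * items_per_thread + i] = item i held by thread t afterwards.
std::vector<int> WarpTransposeLoadModel(const std::vector<int>& src,
                                        int block_threads, int items_per_thread)
{
    assert(block_threads % kWarpThreads == 0);  // model handles whole warps only
    assert(src.size() >= static_cast<std::size_t>(block_threads * items_per_thread));

    const int warp_tile = kWarpThreads * items_per_thread;  // items per warp
    std::vector<int> out(static_cast<std::size_t>(block_threads * items_per_thread));
    std::vector<int> scratch(static_cast<std::size_t>(warp_tile));  // per-warp "shared memory"

    for (int warp = 0; warp < block_threads / kWarpThreads; ++warp) {
        // Step 1: warp-striped reads, stored to scratch in the same striped order.
        for (int i = 0; i < items_per_thread; ++i)
            for (int lane = 0; lane < kWarpThreads; ++lane)
                scratch[i * kWarpThreads + lane] =
                    src[StripedReadIndex(warp, lane, i, items_per_thread)];
        // Step 2: each lane reads back a contiguous (blocked) run from scratch.
        for (int lane = 0; lane < kWarpThreads; ++lane)
            for (int i = 0; i < items_per_thread; ++i)
                out[(warp * kWarpThreads + lane) * items_per_thread + i] =
                    scratch[lane * items_per_thread + i];
    }
    return out;
}
```

Because the scratch buffer mirrors the warp's memory segment, the flattened output equals the input order; what the exchange changes is which simulated thread holds each item.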

Parameters
  • block_src_it[in] The thread block’s base iterator for loading from

  • dst_items[out] Destination to load data into

template<typename RandomAccessIterator>
inline void Load(RandomAccessIterator block_src_it, T (&dst_items)[ITEMS_PER_THREAD], int block_items_end)

Load a linear segment of items from memory, guarded by range.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • A subsequent __syncthreads() thread block barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the guarded loading of a linear segment of 512 integers into a “blocked” arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for BLOCK_LOAD_WARP_TRANSPOSE, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads).

#include <cub/cub.cuh>   // or equivalently <cub/block/block_load.cuh>

__global__ void ExampleKernel(int *d_data, int block_items_end /*, ... */)
{
    // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each
    using BlockLoad = cub::BlockLoad<int, 128, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE>;

    // Allocate shared memory for BlockLoad
    __shared__ typename BlockLoad::TempStorage temp_storage;

    // Load a segment of consecutive items that are blocked across threads
    int thread_data[4];
    BlockLoad(temp_storage).Load(d_data, thread_data, block_items_end);
}

Suppose the input d_data is 0, 1, 2, 3, 4, 5, 6, ... and block_items_end is 5. The set of thread_data across the block of threads will be { [0,1,2,3], [4,?,?,?], ..., [?,?,?,?] }, with only the first two threads unmasked to load valid data; all other items remain unassigned.

Parameters
  • block_src_it[in] The thread block’s base iterator for loading from

  • dst_items[out] Destination to load data into

  • block_items_end[in] Number of valid items to load
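A host-side sketch of the range guard: slot i of thread t corresponds to the linear position t * ITEMS_PER_THREAD + i and is loaded only when that position is below block_items_end. For illustration, the model leaves out-of-range slots untouched so a pre-filled marker stays visible; this is an assumption of the model, since the documentation only says such items remain unassigned, so their contents should be treated as unspecified. GuardedBlockedLoadModel is a hypothetical name, not CUB code.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kItemsPerThread = 4;
using ThreadItems = std::array<int, kItemsPerThread>;

// Host-side model of Load(block_src_it, dst_items, block_items_end):
// only slots whose blocked position is < block_items_end are written.
void GuardedBlockedLoadModel(const std::vector<int>& src,
                             std::vector<ThreadItems>& thread_data,
                             int block_items_end)
{
    assert(block_items_end >= 0);
    assert(static_cast<std::size_t>(block_items_end) <= src.size());
    for (std::size_t t = 0; t < thread_data.size(); ++t)
        for (int i = 0; i < kItemsPerThread; ++i) {
            const int idx = static_cast<int>(t) * kItemsPerThread + i;  // blocked position
            if (idx < block_items_end)
                thread_data[t][i] = src[idx];  // in range: load from memory
            // out of range: slot is not written by this model
        }
}
```

With block_items_end = 5 and slots pre-filled with 999, thread 0 ends with [0,1,2,3], thread 1 with [4,999,999,999], and every later thread keeps its marker values.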

template<typename RandomAccessIterator, typename DefaultT>
inline void Load(RandomAccessIterator block_src_it, T (&dst_items)[ITEMS_PER_THREAD], int block_items_end, DefaultT oob_default)

Load a linear segment of items from memory, guarded by range, with a fall-back assignment of out-of-bounds elements.

  • Assumes a blocked arrangement of (block-threads * items-per-thread) items across the thread block, where thread i owns the i-th range of items-per-thread contiguous items. For multi-dimensional thread blocks, a row-major thread ordering is assumed.

  • A subsequent __syncthreads() thread block barrier should be invoked after calling this method if the collective’s temporary storage (e.g., temp_storage) is to be reused or repurposed.

Snippet

The code snippet below illustrates the guarded loading of a linear segment of 512 integers into a “blocked” arrangement across 128 threads where each thread owns 4 consecutive items. The load is specialized for BLOCK_LOAD_WARP_TRANSPOSE, meaning memory references are efficiently coalesced using a warp-striped access pattern (after which items are locally reordered among threads).

#include <cub/cub.cuh>   // or equivalently <cub/block/block_load.cuh>

__global__ void ExampleKernel(int *d_data, int block_items_end /*, ... */)
{
    // Specialize BlockLoad for a 1D block of 128 threads owning 4 integer items each
    using BlockLoad = cub::BlockLoad<int, 128, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE>;

    // Allocate shared memory for BlockLoad
    __shared__ typename BlockLoad::TempStorage temp_storage;

    // Load a segment of consecutive items that are blocked across threads,
    // assigning -1 to any item at or beyond block_items_end
    int thread_data[4];
    BlockLoad(temp_storage).Load(d_data, thread_data, block_items_end, -1);
}

Suppose the input d_data is 0, 1, 2, 3, 4, 5, 6, ..., block_items_end is 5, and the out-of-bounds default is -1. The set of thread_data across the block of threads will be { [0,1,2,3], [4,-1,-1,-1], ..., [-1,-1,-1,-1] }, with only the first two threads unmasked to load valid data; all other items are assigned -1.

Parameters
  • block_src_it[in] The thread block’s base iterator for loading from

  • dst_items[out] Destination to load data into

  • block_items_end[in] Number of valid items to load

  • oob_default[in] Default value to assign to out-of-bounds items
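The defaulted overload differs from the previous one only in that every slot is written. A host-side sketch under the same assumptions (hypothetical name GuardedLoadWithDefaultModel, not CUB code):

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kItemsPerThread = 4;
using ThreadItems = std::array<int, kItemsPerThread>;

// Host-side model of Load(block_src_it, dst_items, block_items_end, oob_default):
// each slot gets either its in-range value from memory or oob_default.
std::vector<ThreadItems> GuardedLoadWithDefaultModel(const std::vector<int>& src,
                                                     int block_threads,
                                                     int block_items_end,
                                                     int oob_default)
{
    assert(block_items_end >= 0);
    assert(static_cast<std::size_t>(block_items_end) <= src.size());
    std::vector<ThreadItems> thread_data(static_cast<std::size_t>(block_threads));
    for (int t = 0; t < block_threads; ++t)
        for (int i = 0; i < kItemsPerThread; ++i) {
            const int idx = t * kItemsPerThread + i;  // blocked position of this slot
            thread_data[t][i] = (idx < block_items_end) ? src[idx] : oob_default;
        }
    return thread_data;
}
```

With block_items_end = 5 and oob_default = -1 this reproduces the { [0,1,2,3], [4,-1,-1,-1], ..., [-1,-1,-1,-1] } result described above.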

Public Types

using TempStorage = Uninitialized<_TempStorage>

The operations exposed by BlockLoad require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the __shared__ keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or union’d with other storage allocation types to facilitate memory reuse.
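The union-based reuse mentioned above can be sketched in plain C++. The model is loosely based on the idea behind Uninitialized<>: wrapping a layout in raw, suitably aligned bytes gives it a trivial constructor, so it can share a union with other storage types even when those types have non-trivial constructors. UninitializedModel, LoadScratch, ReduceScratch, and SharedRegion are hypothetical names; on the GPU the union would be a __shared__ variable, with a __syncthreads() barrier between the phases that use it.

```cpp
#include <type_traits>

// Host-side sketch (not CUB code): raw, aligned bytes that are never
// constructed, so the wrapper is trivially default constructible.
template <typename T>
struct UninitializedModel
{
    alignas(T) unsigned char bytes[sizeof(T)];            // raw storage for one T
    T& Alias() { return *reinterpret_cast<T*>(bytes); }   // view the bytes as T
};

// Hypothetical stand-ins for two collectives' internal storage layouts.
struct LoadScratch   { int exchange[128 * 4]; };
struct ReduceScratch { double partials[32]; ReduceScratch() : partials{} {} };  // non-trivial ctor

// One memory region reused by two phases of a kernel.
union SharedRegion
{
    UninitializedModel<LoadScratch>   load;    // phase 1: loading
    UninitializedModel<ReduceScratch> reduce;  // phase 2: reducing
};

// Without the wrappers, ReduceScratch's constructor would delete the union's
// default constructor; with them, the union stays trivially constructible.
static_assert(std::is_trivially_default_constructible<SharedRegion>::value,
              "SharedRegion must be trivially default constructible");
```

The union is as large as its biggest member, so both phases fit in the same region without a separate allocation for each.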