// cub/block/block_store.cuh
//
// File members: cub/block/block_store.cuh

/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/block/block_exchange.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>

CUB_NAMESPACE_BEGIN

//! Stores a blocked arrangement of items directly to the output, without bounds checking.
//!
//! The calling thread writes its ITEMS_PER_THREAD items to the consecutive positions
//! block_itr[linear_tid * ITEMS_PER_THREAD + k] for k in [0, ITEMS_PER_THREAD).
//!
//! @param linear_tid  Linear index of the calling thread; selects which contiguous segment it writes.
//! @param block_itr   Output iterator positioned at the start of the block's output segment.
//! @param items       The calling thread's items, in order.
template <typename T, int ITEMS_PER_THREAD, typename OutputIteratorT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
StoreDirectBlocked(int linear_tid, OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
{
  // First output position owned by this thread.
  const int thread_offset = linear_tid * ITEMS_PER_THREAD;
  OutputIteratorT out     = block_itr + thread_offset;

  // Each thread writes a contiguous run of ITEMS_PER_THREAD elements.
#pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
  {
    out[item] = items[item];
  }
}

//! Guarded variant of the blocked direct store.
//!
//! Behaves like the unguarded overload, except that an item is written only if its block-relative
//! offset (linear_tid * ITEMS_PER_THREAD + k) is strictly less than @p valid_items.
//!
//! @param linear_tid   Linear index of the calling thread; selects which contiguous segment it writes.
//! @param block_itr    Output iterator positioned at the start of the block's output segment.
//! @param items        The calling thread's items, in order.
//! @param valid_items  Number of leading block-relative positions that may be written.
template <typename T, int ITEMS_PER_THREAD, typename OutputIteratorT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
StoreDirectBlocked(int linear_tid, OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
{
  // Block-relative offset of this thread's first item.
  const int thread_offset = linear_tid * ITEMS_PER_THREAD;
  OutputIteratorT out     = block_itr + thread_offset;

#pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
  {
    // Skip positions at or beyond the end of the valid range.
    if (thread_offset + item >= valid_items)
    {
      continue;
    }
    out[item] = items[item];
  }
}

//! Stores a blocked arrangement of items through CUDA vector types when possible.
//!
//! The vector width is min(4, ITEMS_PER_THREAD) if that width is a power of two and evenly divides
//! ITEMS_PER_THREAD; otherwise a width of 1 is used. Items are staged into a local vector array,
//! which is then written with the blocked direct store.
//!
//! NOTE(review): @p block_ptr is reinterpreted as a pointer to the chosen vector type, so it presumably
//! must satisfy that type's alignment; no check is made here.
//!
//! @param linear_tid  Linear index of the calling thread; selects which contiguous segment it writes.
//! @param block_ptr   Raw pointer to the start of the block's output segment.
//! @param items       The calling thread's items, in order.
template <typename T, int ITEMS_PER_THREAD>
_CCCL_DEVICE _CCCL_FORCEINLINE void
StoreDirectBlockedVectorized(int linear_tid, T* block_ptr, T (&items)[ITEMS_PER_THREAD])
{
  // CUDA vector types carry at most four lanes.
  constexpr int max_vec_size = CUB_MIN(4, ITEMS_PER_THREAD);

  // Only vectorize when the candidate width is a power of two that divides the per-thread item count.
  constexpr bool is_pow2    = ((max_vec_size - 1) & max_vec_size) == 0;
  constexpr bool divides    = (ITEMS_PER_THREAD % max_vec_size) == 0;
  constexpr int vec_size    = (is_pow2 && divides) ? max_vec_size : 1;
  constexpr int num_vectors = ITEMS_PER_THREAD / vec_size;

  using Vector = typename CubVector<T, vec_size>::Type;

  // View the output as an array of vectors.
  Vector* vector_ptr = reinterpret_cast<Vector*>(block_ptr);

  // Stage through a "raw" local array (expected to be optimized away) to avoid conservative PTXAS
  // local-memory spilling.
  Vector staged[num_vectors];
  T* staged_items = reinterpret_cast<T*>(staged);

#pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
  {
    staged_items[item] = items[item];
  }

  // Vector-wide blocked store.
  StoreDirectBlocked(linear_tid, vector_ptr, staged);
}

//! Stores a striped arrangement of items directly to the output, without bounds checking.
//!
//! Item k of the calling thread is written to block_itr[linear_tid + k * BLOCK_THREADS].
//!
//! @tparam BLOCK_THREADS  Stride between a thread's consecutive items.
//! @param linear_tid      Linear index of the calling thread; selects its starting column.
//! @param block_itr       Output iterator positioned at the start of the block's output segment.
//! @param items           The calling thread's items, in order.
template <int BLOCK_THREADS, typename T, int ITEMS_PER_THREAD, typename OutputIteratorT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
StoreDirectStriped(int linear_tid, OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
{
  OutputIteratorT out = block_itr + linear_tid;

#pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
  {
    // Consecutive items of one thread are BLOCK_THREADS positions apart.
    const int stride_offset = item * BLOCK_THREADS;
    out[stride_offset]      = items[item];
  }
}

//! Guarded variant of the striped direct store.
//!
//! Item k of the calling thread is written to block_itr[linear_tid + k * BLOCK_THREADS], but only
//! when that block-relative offset is strictly less than @p valid_items.
//!
//! @tparam BLOCK_THREADS  Stride between a thread's consecutive items.
//! @param linear_tid      Linear index of the calling thread; selects its starting column.
//! @param block_itr       Output iterator positioned at the start of the block's output segment.
//! @param items           The calling thread's items, in order.
//! @param valid_items     Number of leading block-relative positions that may be written.
template <int BLOCK_THREADS, typename T, int ITEMS_PER_THREAD, typename OutputIteratorT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
StoreDirectStriped(int linear_tid, OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
{
  OutputIteratorT out = block_itr + linear_tid;

#pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
  {
    const int stride_offset = item * BLOCK_THREADS;

    // Out-of-range positions are left untouched.
    if (stride_offset + linear_tid >= valid_items)
    {
      continue;
    }
    out[stride_offset] = items[item];
  }
}

//! Stores a warp-striped arrangement of items directly to the output, without bounds checking.
//!
//! Each warp owns a contiguous segment of CUB_PTX_WARP_THREADS * ITEMS_PER_THREAD positions. Within
//! that segment, item k of a lane is written at lane + k * CUB_PTX_WARP_THREADS.
//!
//! NOTE(review): the lane/warp split uses a mask and a shift, which assumes CUB_PTX_WARP_THREADS is a
//! power of two equal to 1 << CUB_PTX_LOG_WARP_THREADS.
//!
//! @param linear_tid  Linear index of the calling thread.
//! @param block_itr   Output iterator positioned at the start of the block's output segment.
//! @param items       The calling thread's items, in order.
template <typename T, int ITEMS_PER_THREAD, typename OutputIteratorT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
StoreDirectWarpStriped(int linear_tid, OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
{
  const int lane        = linear_tid & (CUB_PTX_WARP_THREADS - 1);
  const int warp_id     = linear_tid >> CUB_PTX_LOG_WARP_THREADS;
  const int warp_offset = warp_id * CUB_PTX_WARP_THREADS * ITEMS_PER_THREAD;

  OutputIteratorT out = block_itr + warp_offset + lane;

#pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
  {
    // Within a warp's segment, consecutive items of one lane are a warp-width apart.
    out[item * CUB_PTX_WARP_THREADS] = items[item];
  }
}

//! Guarded variant of the warp-striped direct store.
//!
//! Same layout as the unguarded overload. An item is written only when its block-relative offset
//! (warp_offset + lane + k * CUB_PTX_WARP_THREADS) is strictly less than @p valid_items.
//!
//! NOTE(review): assumes CUB_PTX_WARP_THREADS is a power of two equal to 1 << CUB_PTX_LOG_WARP_THREADS.
//!
//! @param linear_tid   Linear index of the calling thread.
//! @param block_itr    Output iterator positioned at the start of the block's output segment.
//! @param items        The calling thread's items, in order.
//! @param valid_items  Number of leading block-relative positions that may be written.
template <typename T, int ITEMS_PER_THREAD, typename OutputIteratorT>
_CCCL_DEVICE _CCCL_FORCEINLINE void
StoreDirectWarpStriped(int linear_tid, OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
{
  const int lane        = linear_tid & (CUB_PTX_WARP_THREADS - 1);
  const int warp_id     = linear_tid >> CUB_PTX_LOG_WARP_THREADS;
  const int warp_offset = warp_id * CUB_PTX_WARP_THREADS * ITEMS_PER_THREAD;

  // Block-relative offset of this lane's first item.
  const int thread_base = warp_offset + lane;
  OutputIteratorT out   = block_itr + warp_offset + lane;

#pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
  {
    const int stride_offset = item * CUB_PTX_WARP_THREADS;

    // Out-of-range positions are left untouched.
    if (thread_base + stride_offset >= valid_items)
    {
      continue;
    }
    out[stride_offset] = items[item];
  }
}

//-----------------------------------------------------------------------------
// Generic BlockStore abstraction
//-----------------------------------------------------------------------------

//! Data-movement strategies selectable for cub::BlockStore.
//! Each enumerator selects a matching BlockStore::StoreInternal specialization.
enum BlockStoreAlgorithm
{
  //! Blocked store: each thread writes its items as one contiguous run (StoreDirectBlocked).
  BLOCK_STORE_DIRECT = 0,

  //! Striped store: item k of thread t is written at t + k * BLOCK_THREADS (StoreDirectStriped).
  BLOCK_STORE_STRIPED = 1,

  //! Blocked store through CUDA vector types when the output is a raw `T*` (StoreDirectBlockedVectorized).
  //! Other output iterators, and all guarded stores, fall back to StoreDirectBlocked.
  BLOCK_STORE_VECTORIZE = 2,

  //! Blocked-to-striped exchange through BlockExchange temporary storage, then StoreDirectStriped.
  BLOCK_STORE_TRANSPOSE = 3,

  //! Blocked-to-warp-striped exchange through BlockExchange, then StoreDirectWarpStriped.
  //! BLOCK_THREADS must be a multiple of the warp size (enforced by static_assert).
  BLOCK_STORE_WARP_TRANSPOSE = 4,

  //! Same as BLOCK_STORE_WARP_TRANSPOSE, but BlockExchange is instantiated with its fourth template
  //! argument set to `true`. NOTE(review): presumably this enables warp time-slicing — confirm in block_exchange.cuh.
  BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED = 5,
};

//! Collective block-wide store. Writes each thread's blocked array of ITEMS_PER_THREAD items to an output
//! iterator, using the strategy selected by ALGORITHM (dispatched to a StoreInternal specialization below).
//!
//! @tparam T                 Item type.
//! @tparam BLOCK_DIM_X       Thread-block extent in x.
//! @tparam ITEMS_PER_THREAD  Number of items held by each thread.
//! @tparam ALGORITHM         Store strategy (default BLOCK_STORE_DIRECT).
//! @tparam BLOCK_DIM_Y       Thread-block extent in y (default 1).
//! @tparam BLOCK_DIM_Z       Thread-block extent in z (default 1).
//! @tparam LEGACY_PTX_ARCH   Not referenced anywhere in this class.
//!
//! NOTE(review): the TRANSPOSE-family strategies go through BlockExchange, and their guarded stores execute
//! CTA_SYNC. Presumably every thread of the block must therefore call Store(). If a TempStorage is reused for
//! back-to-back stores, a block-wide barrier between them is likely required — confirm against BlockExchange.
template <typename T,
          int BLOCK_DIM_X,
          int ITEMS_PER_THREAD,
          BlockStoreAlgorithm ALGORITHM = BLOCK_STORE_DIRECT,
          int BLOCK_DIM_Y               = 1,
          int BLOCK_DIM_Z               = 1,
          int LEGACY_PTX_ARCH           = 0>
class BlockStore
{
private:
  enum
  {
    // Total number of threads in the (possibly multi-dimensional) thread block.
    BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
  };

  // Per-algorithm implementation, chosen by partial specialization on _POLICY. The DUMMY parameter likely
  // exists so these can be partial (not explicit) specializations at class scope.
  template <BlockStoreAlgorithm _POLICY, int DUMMY>
  struct StoreInternal;

  // BLOCK_STORE_DIRECT: blocked store via StoreDirectBlocked; needs no temporary storage.
  template <int DUMMY>
  struct StoreInternal<BLOCK_STORE_DIRECT, DUMMY>
  {
    using TempStorage = NullType;

    int linear_tid;

    _CCCL_DEVICE _CCCL_FORCEINLINE StoreInternal(TempStorage& /*temp_storage*/, int linear_tid)
        : linear_tid(linear_tid)
    {}

    // Unguarded store of a full tile.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
    {
      StoreDirectBlocked(linear_tid, block_itr, items);
    }

    // Guarded store: only block-relative offsets below valid_items are written.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
    {
      StoreDirectBlocked(linear_tid, block_itr, items, valid_items);
    }
  };

  // BLOCK_STORE_STRIPED: striped store via StoreDirectStriped; needs no temporary storage.
  template <int DUMMY>
  struct StoreInternal<BLOCK_STORE_STRIPED, DUMMY>
  {
    using TempStorage = NullType;

    int linear_tid;

    _CCCL_DEVICE _CCCL_FORCEINLINE StoreInternal(TempStorage& /*temp_storage*/, int linear_tid)
        : linear_tid(linear_tid)
    {}

    // Unguarded store of a full tile.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
    {
      StoreDirectStriped<BLOCK_THREADS>(linear_tid, block_itr, items);
    }

    // Guarded store: only block-relative offsets below valid_items are written.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
    {
      StoreDirectStriped<BLOCK_THREADS>(linear_tid, block_itr, items, valid_items);
    }
  };

  // BLOCK_STORE_VECTORIZE: vectorized blocked store for raw T* outputs, plain blocked store otherwise;
  // needs no temporary storage.
  template <int DUMMY>
  struct StoreInternal<BLOCK_STORE_VECTORIZE, DUMMY>
  {
    using TempStorage = NullType;

    int linear_tid;

    _CCCL_DEVICE _CCCL_FORCEINLINE StoreInternal(TempStorage& /*temp_storage*/, int linear_tid)
        : linear_tid(linear_tid)
    {}

    // Non-template overload for exactly T*. Overload resolution prefers it over the template below, so raw
    // pointers take the vectorized path. NOTE(review): block_ptr presumably must be aligned for the
    // vector type chosen by StoreDirectBlockedVectorized; nothing checks this here.
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(T* block_ptr, T (&items)[ITEMS_PER_THREAD])
    {
      StoreDirectBlockedVectorized(linear_tid, block_ptr, items);
    }

    // Fallback for generic output iterators: plain blocked store.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
    {
      StoreDirectBlocked(linear_tid, block_itr, items);
    }

    // Guarded stores are never vectorized: plain blocked store with bounds checks.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
    {
      StoreDirectBlocked(linear_tid, block_itr, items, valid_items);
    }
  };

  // BLOCK_STORE_TRANSPOSE: exchange blocked -> striped through BlockExchange, then striped store.
  template <int DUMMY>
  struct StoreInternal<BLOCK_STORE_TRANSPOSE, DUMMY>
  {
    // BlockExchange utility type for keys
    using BlockExchange = BlockExchange<T, BLOCK_DIM_X, ITEMS_PER_THREAD, false, BLOCK_DIM_Y, BLOCK_DIM_Z>;

    // Exchange storage plus a slot used to broadcast valid_items from thread 0 in the guarded path.
    struct _TempStorage : BlockExchange::TempStorage
    {
      volatile int valid_items;
    };

    struct TempStorage : Uninitialized<_TempStorage>
    {};

    _TempStorage& temp_storage;

    int linear_tid;

    _CCCL_DEVICE _CCCL_FORCEINLINE StoreInternal(TempStorage& temp_storage, int linear_tid)
        : temp_storage(temp_storage.Alias())
        , linear_tid(linear_tid)
    {}

    // Unguarded: transpose to striped order, then store the full tile.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
    {
      BlockExchange(temp_storage).BlockedToStriped(items);
      StoreDirectStriped<BLOCK_THREADS>(linear_tid, block_itr, items);
    }

    // Guarded: transpose, then thread 0 publishes valid_items through temp storage. After CTA_SYNC,
    // every thread reads it back for the bounds-checked striped store.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
    {
      BlockExchange(temp_storage).BlockedToStriped(items);
      if (linear_tid == 0)
      {
        // Move through volatile smem as a workaround to prevent RF spilling on
        // subsequent loads
        temp_storage.valid_items = valid_items;
      }
      // Must be reached by all threads: orders thread 0's write before everyone's read below.
      CTA_SYNC();
      StoreDirectStriped<BLOCK_THREADS>(linear_tid, block_itr, items, temp_storage.valid_items);
    }
  };

  // BLOCK_STORE_WARP_TRANSPOSE: exchange blocked -> warp-striped through BlockExchange, then warp-striped store.
  template <int DUMMY>
  struct StoreInternal<BLOCK_STORE_WARP_TRANSPOSE, DUMMY>
  {
    enum
    {
      WARP_THREADS = CUB_WARP_THREADS(0)
    };

    // Assert BLOCK_THREADS must be a multiple of WARP_THREADS
    static_assert(int(BLOCK_THREADS) % int(WARP_THREADS) == 0, "BLOCK_THREADS must be a multiple of WARP_THREADS");

    // BlockExchange utility type for keys
    using BlockExchange = BlockExchange<T, BLOCK_DIM_X, ITEMS_PER_THREAD, false, BLOCK_DIM_Y, BLOCK_DIM_Z>;

    // Exchange storage plus a slot used to broadcast valid_items from thread 0 in the guarded path.
    struct _TempStorage : BlockExchange::TempStorage
    {
      volatile int valid_items;
    };

    struct TempStorage : Uninitialized<_TempStorage>
    {};

    _TempStorage& temp_storage;

    int linear_tid;

    _CCCL_DEVICE _CCCL_FORCEINLINE StoreInternal(TempStorage& temp_storage, int linear_tid)
        : temp_storage(temp_storage.Alias())
        , linear_tid(linear_tid)
    {}

    // Unguarded: transpose to warp-striped order, then store the full tile.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
    {
      BlockExchange(temp_storage).BlockedToWarpStriped(items);
      StoreDirectWarpStriped(linear_tid, block_itr, items);
    }

    // Guarded: transpose, then thread 0 publishes valid_items through temp storage. After CTA_SYNC,
    // every thread reads it back for the bounds-checked warp-striped store.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
    {
      BlockExchange(temp_storage).BlockedToWarpStriped(items);
      if (linear_tid == 0)
      {
        // Move through volatile smem as a workaround to prevent RF spilling on
        // subsequent loads
        temp_storage.valid_items = valid_items;
      }
      // Must be reached by all threads: orders thread 0's write before everyone's read below.
      CTA_SYNC();
      StoreDirectWarpStriped(linear_tid, block_itr, items, temp_storage.valid_items);
    }
  };

  // BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED: like BLOCK_STORE_WARP_TRANSPOSE, but BlockExchange is instantiated
  // with its fourth template argument set to true (presumably warp time-slicing — confirm in block_exchange.cuh).
  template <int DUMMY>
  struct StoreInternal<BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, DUMMY>
  {
    enum
    {
      WARP_THREADS = CUB_WARP_THREADS(0)
    };

    // Assert BLOCK_THREADS must be a multiple of WARP_THREADS
    static_assert(int(BLOCK_THREADS) % int(WARP_THREADS) == 0, "BLOCK_THREADS must be a multiple of WARP_THREADS");

    // BlockExchange utility type for keys
    using BlockExchange = BlockExchange<T, BLOCK_DIM_X, ITEMS_PER_THREAD, true, BLOCK_DIM_Y, BLOCK_DIM_Z>;

    // Exchange storage plus a slot used to broadcast valid_items from thread 0 in the guarded path.
    struct _TempStorage : BlockExchange::TempStorage
    {
      volatile int valid_items;
    };

    struct TempStorage : Uninitialized<_TempStorage>
    {};

    _TempStorage& temp_storage;

    int linear_tid;

    _CCCL_DEVICE _CCCL_FORCEINLINE StoreInternal(TempStorage& temp_storage, int linear_tid)
        : temp_storage(temp_storage.Alias())
        , linear_tid(linear_tid)
    {}

    // Unguarded: transpose to warp-striped order, then store the full tile.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
    {
      BlockExchange(temp_storage).BlockedToWarpStriped(items);
      StoreDirectWarpStriped(linear_tid, block_itr, items);
    }

    // Guarded: transpose, then thread 0 publishes valid_items through temp storage. After CTA_SYNC,
    // every thread reads it back for the bounds-checked warp-striped store.
    template <typename OutputIteratorT>
    _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
    {
      BlockExchange(temp_storage).BlockedToWarpStriped(items);
      if (linear_tid == 0)
      {
        // Move through volatile smem as a workaround to prevent RF spilling on
        // subsequent loads
        temp_storage.valid_items = valid_items;
      }
      // Must be reached by all threads: orders thread 0's write before everyone's read below.
      CTA_SYNC();
      StoreDirectWarpStriped(linear_tid, block_itr, items, temp_storage.valid_items);
    }
  };

  // Implementation selected by ALGORITHM.
  using InternalStore = StoreInternal<ALGORITHM, 0>;

  // Temporary storage required by the selected implementation (NullType for the non-exchanging strategies).
  using _TempStorage = typename InternalStore::TempStorage;

  // Backing storage for the default constructor: a __shared__ variable declared inside this function.
  _CCCL_DEVICE _CCCL_FORCEINLINE _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage private_storage;
    return private_storage;
  }

  // Temporary storage in use (caller-supplied, or PrivateStorage()).
  _TempStorage& temp_storage;

  // Calling thread's linear index, computed by RowMajorTid from the block dimensions.
  int linear_tid;

public:
  //! Opaque temporary-storage type that callers allocate and pass to the constructor.
  //! NOTE(review): presumably meant to be placed in __shared__ memory by the caller.
  struct TempStorage : Uninitialized<_TempStorage>
  {};

  //! Constructs a BlockStore that uses its own function-scope __shared__ temporary storage.
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockStore()
      : temp_storage(PrivateStorage())
      , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
  {}

  //! Constructs a BlockStore that uses caller-provided temporary storage.
  //! @param temp_storage  Storage aliased (via Alias()) for the lifetime of this object.
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockStore(TempStorage& temp_storage)
      : temp_storage(temp_storage.Alias())
      , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
  {}

  //! Stores a full tile: the block's blocked items are written through block_itr using ALGORITHM.
  //! @param block_itr  Output iterator positioned at the start of the block's output segment.
  //! @param items      The calling thread's items, in blocked order.
  template <typename OutputIteratorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD])
  {
    InternalStore(temp_storage, linear_tid).Store(block_itr, items);
  }

  //! Stores a partial tile: only block-relative output positions below valid_items are written.
  //! @param block_itr    Output iterator positioned at the start of the block's output segment.
  //! @param items        The calling thread's items, in blocked order.
  //! @param valid_items  Number of leading block-relative positions that may be written.
  template <typename OutputIteratorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void Store(OutputIteratorT block_itr, T (&items)[ITEMS_PER_THREAD], int valid_items)
  {
    InternalStore(temp_storage, linear_tid).Store(block_itr, items, valid_items);
  }

};

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Maps a tuning policy and an output iterator to the matching cub::BlockStore instantiation.
// The policy supplies BLOCK_THREADS, ITEMS_PER_THREAD and STORE_ALGORITHM. The item type T defaults to
// cub::detail::value_t<It>.
template <typename Policy, typename It, typename T = cub::detail::value_t<It>>
struct BlockStoreType
{
  using type = cub::BlockStore<T, Policy::BLOCK_THREADS, Policy::ITEMS_PER_THREAD, Policy::STORE_ALGORITHM>;
};
#endif // _CCCL_DOXYGEN_INVOKED

CUB_NAMESPACE_END