cub/block/block_exchange.cuh

File members: cub/block/block_exchange.cuh

/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/detail/uninitialized_copy.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/warp/warp_exchange.cuh>

CUB_NAMESPACE_BEGIN

//! @brief BlockExchange rearranges the items held by the threads of a block by routing them through a shared-memory
//! buffer (`TempStorage`). It converts between blocked, striped and warp-striped arrangements and can scatter items
//! by per-item rank.
//!
//! With `BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z`, `t` the row-major thread rank (`RowMajorTid`) and
//! `i` in `[0, ITEMS_PER_THREAD)`:
//!   - blocked:      thread `t` holds tile item `t * ITEMS_PER_THREAD + i`
//!   - striped:      thread `t` holds tile item `i * BLOCK_THREADS + t`
//!   - warp-striped: inside the segment of `ITEMS_PER_THREAD * min(BLOCK_THREADS, warp size)` items owned by its
//!                   warp, lane `l` holds segment item `i * min(BLOCK_THREADS, warp size) + l`
//!
//! @tparam T                 Item type
//! @tparam BLOCK_DIM_X       Block length in threads along x
//! @tparam ITEMS_PER_THREAD  Number of items held by each thread
//! @tparam WARP_TIME_SLICING If `true`, the shared buffer holds only one warp's worth of items. The warps take turns
//!                           using it (one time slice per warp), trading extra synchronization for less shared
//!                           memory.
//! @tparam BLOCK_DIM_Y       Block length in threads along y
//! @tparam BLOCK_DIM_Z       Block length in threads along z
//! @tparam LEGACY_PTX_ARCH   Not referenced by this class
//!
//! NOTE(review): the non-time-sliced paths start writing to the shared buffer without a leading barrier. Reusing one
//! TempStorage for back-to-back exchanges therefore presumably needs a CTA_SYNC() between the calls — confirm.
template <typename T,
          int BLOCK_DIM_X,
          int ITEMS_PER_THREAD,
          bool WARP_TIME_SLICING = false,
          int BLOCK_DIM_Y        = 1,
          int BLOCK_DIM_Z        = 1,
          int LEGACY_PTX_ARCH    = 0>
class BlockExchange
{
  // Total number of threads in the block.
  static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;
  static constexpr int WARP_THREADS  = CUB_WARP_THREADS(0);
  // Number of warps in the block, rounded up. TODO(bgruber): use ceil_div in C++14
  static constexpr int WARPS          = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS;
  static constexpr int LOG_SMEM_BANKS = CUB_LOG_SMEM_BANKS(0);

  static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;

  // With time slicing, each warp in turn owns the whole shared buffer, so the buffer only holds one warp's items and
  // the exchange runs in WARPS slices. Without it, the buffer holds the whole tile and the exchange runs in one pass.
  static constexpr int TIME_SLICES         = WARP_TIME_SLICING ? WARPS : 1;
  static constexpr int TIME_SLICED_THREADS = WARP_TIME_SLICING ? CUB_MIN(BLOCK_THREADS, WARP_THREADS) : BLOCK_THREADS;
  static constexpr int TIME_SLICED_ITEMS   = TIME_SLICED_THREADS * ITEMS_PER_THREAD;

  // Geometry of one warp's segment, used by the warp-striped conversions.
  static constexpr int WARP_TIME_SLICED_THREADS = CUB_MIN(BLOCK_THREADS, WARP_THREADS);
  static constexpr int WARP_TIME_SLICED_ITEMS   = WARP_TIME_SLICED_THREADS * ITEMS_PER_THREAD;

  // Insert padding to avoid bank conflicts during raking when items per thread is a power of two and > 4 (otherwise
  // we can typically use 128b loads). A logical offset `o` is stored at physical offset `o + (o >> LOG_SMEM_BANKS)`,
  // i.e. one padding slot per 2^LOG_SMEM_BANKS items.
  static constexpr bool INSERT_PADDING = ITEMS_PER_THREAD > 4 && PowerOfTwo<ITEMS_PER_THREAD>::VALUE;
  static constexpr int PADDING_ITEMS   = INSERT_PADDING ? (TIME_SLICED_ITEMS >> LOG_SMEM_BANKS) : 0;

  // Shared staging buffer: one time slice worth of items plus padding slots.
  struct alignas(16) _TempStorage
  {
    T buff[TIME_SLICED_ITEMS + PADDING_ITEMS];
  };

public:
  //! @brief Opaque shared-memory storage required by BlockExchange.
  //! It is wrapped in `Uninitialized`, so the elements of the buffer are not constructed up front. The write paths of
  //! this class construct them with `detail::uninitialized_copy_single`.
  using TempStorage = Uninitialized<_TempStorage>;

private:
  // Shared buffer used for the exchange: either caller-provided or PrivateStorage().
  _TempStorage& temp_storage;

  // Per-thread coordinates, computed when the object is constructed.
  // TODO(bgruber): can we use signed int here? Only these variables are unsigned:
  unsigned int linear_tid = RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z);
  unsigned int lane_id    = LaneId();
  unsigned int warp_id    = WARPS == 1 ? 0 : linear_tid / WARP_THREADS;
  // Start of this warp's segment in the shared buffer (non-time-sliced warp-striped paths only).
  unsigned int warp_offset = warp_id * WARP_TIME_SLICED_ITEMS;

  //! Returns a `__shared__` buffer declared inside this function; used when no TempStorage is supplied.
  _CCCL_DEVICE _CCCL_FORCEINLINE _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage private_storage;
    return private_storage;
  }

  //! Blocked -> striped in a single pass.
  //! Every thread stores its items at their blocked positions, the block synchronizes, then every thread loads its
  //! striped positions.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void BlockedToStriped(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    Int2Type<false> /*time_slicing*/)
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = linear_tid * ITEMS_PER_THREAD + i;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
    }

    CTA_SYNC();

#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = i * BLOCK_THREADS + linear_tid;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

  //! Blocked -> striped with warp time slicing.
  //! In slice `s`, warp `s` stores its blocked items, so the buffer holds tile items
  //! [s * TIME_SLICED_ITEMS, (s + 1) * TIME_SLICED_ITEMS). After a barrier, every thread loads whichever of its
  //! striped items fall inside that window. Results are staged in `temp_items` and copied out only after the last
  //! slice. This keeps the in-place overload (input == output) correct, because later warps still need their inputs
  //! after earlier slices have been read.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void BlockedToStriped(
    const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], Int2Type<true> /*time_slicing*/)
  {
    T temp_items[ITEMS_PER_THREAD];

#pragma unroll
    for (int slice = 0; slice < TIME_SLICES; slice++)
    {
      const int slice_offset = slice * TIME_SLICED_ITEMS;
      const int slice_oob    = slice_offset + TIME_SLICED_ITEMS;

      // Make sure the previous slice has been fully read before overwriting the buffer.
      CTA_SYNC();

      if (warp_id == slice)
      {
#pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; i++)
        {
          int item_offset = lane_id * ITEMS_PER_THREAD + i;
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset += item_offset >> LOG_SMEM_BANKS;
          }
          detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
        }
      }

      CTA_SYNC();

#pragma unroll
      for (int i = 0; i < ITEMS_PER_THREAD; i++)
      {
        // Read a strip of items
        const int strip_offset = i * BLOCK_THREADS;
        const int strip_oob    = strip_offset + BLOCK_THREADS;

        // Only strips that overlap the current slice window can contribute.
        if (slice_offset < strip_oob && slice_oob > strip_offset)
        {
          int item_offset = strip_offset + linear_tid - slice_offset;
          if (item_offset >= 0 && item_offset < TIME_SLICED_ITEMS)
          {
            _CCCL_IF_CONSTEXPR (INSERT_PADDING)
            {
              item_offset += item_offset >> LOG_SMEM_BANKS;
            }
            temp_items[i] = temp_storage.buff[item_offset];
          }
        }
      }
    }

// Copy
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      output_items[i] = temp_items[i];
    }
  }

  //! Blocked -> warp-striped without time slicing.
  //! Each warp works only inside its own segment (at `warp_offset`), so a warp-level sync is the only barrier needed.
  //! NOTE(review): the full 0xffffffff mask presumably assumes every warp is fully populated — confirm for partial
  //! warps.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void BlockedToWarpStriped(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    Int2Type<false> /*time_slicing*/)
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = warp_offset + i + (lane_id * ITEMS_PER_THREAD);
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
    }

    WARP_SYNC(0xffffffff);

#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = warp_offset + (i * WARP_TIME_SLICED_THREADS) + lane_id;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

  //! Blocked -> warp-striped with warp time slicing.
  //! Warp 0 goes first without a leading block barrier. Each later warp takes its turn after a CTA_SYNC, using the
  //! start of the buffer. A warp only reads back what it wrote itself, so outputs are written directly.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void BlockedToWarpStriped(
    const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], Int2Type<true> /*time_slicing*/)
  {
    if (warp_id == 0)
    {
#pragma unroll
      for (int i = 0; i < ITEMS_PER_THREAD; i++)
      {
        int item_offset = i + lane_id * ITEMS_PER_THREAD;
        _CCCL_IF_CONSTEXPR (INSERT_PADDING)
        {
          item_offset += item_offset >> LOG_SMEM_BANKS;
        }
        detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
      }

      WARP_SYNC(0xffffffff);

#pragma unroll
      for (int i = 0; i < ITEMS_PER_THREAD; i++)
      {
        int item_offset = i * WARP_TIME_SLICED_THREADS + lane_id;
        _CCCL_IF_CONSTEXPR (INSERT_PADDING)
        {
          item_offset += item_offset >> LOG_SMEM_BANKS;
        }
        output_items[i] = temp_storage.buff[item_offset];
      }
    }

#pragma unroll
    for (int slice = 1; slice < TIME_SLICES; ++slice)
    {
      // Reached by every thread: the previous warp must be done with the buffer.
      CTA_SYNC();

      if (warp_id == slice)
      {
#pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; i++)
        {
          int item_offset = i + lane_id * ITEMS_PER_THREAD;
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset += item_offset >> LOG_SMEM_BANKS;
          }
          detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
        }

        WARP_SYNC(0xffffffff);

#pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; i++)
        {
          int item_offset = i * WARP_TIME_SLICED_THREADS + lane_id;
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset += item_offset >> LOG_SMEM_BANKS;
          }
          output_items[i] = temp_storage.buff[item_offset];
        }
      }
    }
  }

  //! Striped -> blocked in a single pass.
  //! Stores at striped positions, then a block barrier, then loads at blocked positions.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void StripedToBlocked(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    Int2Type<false> /*time_slicing*/)
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = i * BLOCK_THREADS + linear_tid;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
    }

    CTA_SYNC();

// No timeslicing
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = linear_tid * ITEMS_PER_THREAD + i;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

  //! Striped -> blocked with warp time slicing.
  //! In slice `s`, every thread stores those of its striped items that fall in the window
  //! [s * TIME_SLICED_ITEMS, (s + 1) * TIME_SLICED_ITEMS). After a barrier, warp `s` loads its blocked items.
  //! Results are staged in `temp_items` so that the in-place overload stays correct.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void StripedToBlocked(
    const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], Int2Type<true> /*time_slicing*/)
  {
    // Warp time-slicing
    T temp_items[ITEMS_PER_THREAD];

#pragma unroll
    for (int slice = 0; slice < TIME_SLICES; slice++)
    {
      const int slice_offset = slice * TIME_SLICED_ITEMS;
      const int slice_oob    = slice_offset + TIME_SLICED_ITEMS;

      CTA_SYNC();

#pragma unroll
      for (int i = 0; i < ITEMS_PER_THREAD; i++)
      {
        // Write a strip of items
        const int strip_offset = i * BLOCK_THREADS;
        const int strip_oob    = strip_offset + BLOCK_THREADS;

        if (slice_offset < strip_oob && slice_oob > strip_offset)
        {
          int item_offset = strip_offset + linear_tid - slice_offset;
          if (item_offset >= 0 && item_offset < TIME_SLICED_ITEMS)
          {
            _CCCL_IF_CONSTEXPR (INSERT_PADDING)
            {
              item_offset += item_offset >> LOG_SMEM_BANKS;
            }
            detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
          }
        }
      }

      CTA_SYNC();

      if (warp_id == slice)
      {
#pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; i++)
        {
          int item_offset = lane_id * ITEMS_PER_THREAD + i;
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset += item_offset >> LOG_SMEM_BANKS;
          }
          temp_items[i] = temp_storage.buff[item_offset];
        }
      }
    }

// Copy
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      output_items[i] = temp_items[i];
    }
  }

  //! Warp-striped -> blocked without time slicing; each warp works inside its own segment at `warp_offset`.
  //! NOTE(review): the full 0xffffffff mask presumably assumes every warp is fully populated — confirm.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void WarpStripedToBlocked(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    Int2Type<false> /*time_slicing*/)
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = warp_offset + (i * WARP_TIME_SLICED_THREADS) + lane_id;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
    }

    WARP_SYNC(0xffffffff);

#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = warp_offset + i + (lane_id * ITEMS_PER_THREAD);
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset += item_offset >> LOG_SMEM_BANKS;
      }
      // `output_items` already holds live objects, so assign (as every other read-back loop in this class does)
      // rather than constructing over them.
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

  //! Warp-striped -> blocked with warp time slicing.
  //! Warps take turns using the start of the buffer, separated by block barriers.
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void WarpStripedToBlocked(
    const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD], Int2Type<true> /*time_slicing*/)
  {
#pragma unroll
    for (int slice = 0; slice < TIME_SLICES; ++slice)
    {
      CTA_SYNC();

      if (warp_id == slice)
      {
#pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; i++)
        {
          int item_offset = i * WARP_TIME_SLICED_THREADS + lane_id;
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset += item_offset >> LOG_SMEM_BANKS;
          }
          detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
        }

        WARP_SYNC(0xffffffff);

#pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; i++)
        {
          int item_offset = i + lane_id * ITEMS_PER_THREAD;
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset += item_offset >> LOG_SMEM_BANKS;
          }
          output_items[i] = temp_storage.buff[item_offset];
        }
      }
    }
  }

  //! Scatter by rank -> blocked in a single pass.
  //! Item `i` is stored at tile position `ranks[i]` (no bounds check; presumably ranks lie in [0, TILE_ITEMS)).
  template <typename OutputT, typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToBlocked(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD],
    Int2Type<false> /*time_slicing*/)
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = ranks[i];
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
      }
      detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
    }

    CTA_SYNC();

#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = linear_tid * ITEMS_PER_THREAD + i;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
      }
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

  //! Scatter by rank -> blocked with warp time slicing.
  //! In slice `s`, only items whose rank falls in [s * TIME_SLICED_ITEMS, (s + 1) * TIME_SLICED_ITEMS) are stored;
  //! warp `s` then loads its blocked items. Ranks outside every slice window are never stored.
  template <typename OutputT, typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToBlocked(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD],
    Int2Type<true> /*time_slicing*/)
  {
    T temp_items[ITEMS_PER_THREAD];

#pragma unroll
    for (int slice = 0; slice < TIME_SLICES; slice++)
    {
      CTA_SYNC();

      const int slice_offset = TIME_SLICED_ITEMS * slice;

#pragma unroll
      for (int i = 0; i < ITEMS_PER_THREAD; i++)
      {
        // Window bound matches the slice stride (TIME_SLICED_ITEMS == WARP_TIME_SLICED_ITEMS when time slicing).
        int item_offset = ranks[i] - slice_offset;
        if (item_offset >= 0 && item_offset < TIME_SLICED_ITEMS)
        {
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
          }
          detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
        }
      }

      CTA_SYNC();

      if (warp_id == slice)
      {
#pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; i++)
        {
          int item_offset = lane_id * ITEMS_PER_THREAD + i;
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
          }
          temp_items[i] = temp_storage.buff[item_offset];
        }
      }
    }

// Copy
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      output_items[i] = temp_items[i];
    }
  }

  //! Scatter by rank -> striped in a single pass (no bounds check on ranks).
  template <typename OutputT, typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToStriped(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD],
    Int2Type<false> /*time_slicing*/)
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = ranks[i];
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
      }
      detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
    }

    CTA_SYNC();

#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = i * BLOCK_THREADS + linear_tid;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
      }
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

  //! Scatter by rank -> striped with warp time slicing.
  //! In each slice, items whose rank falls in the current window are stored; after a barrier, every thread loads the
  //! striped items that lie in that window.
  template <typename OutputT, typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToStriped(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD],
    Int2Type<true> /*time_slicing*/)
  {
    T temp_items[ITEMS_PER_THREAD];

#pragma unroll
    for (int slice = 0; slice < TIME_SLICES; slice++)
    {
      const int slice_offset = slice * TIME_SLICED_ITEMS;
      const int slice_oob    = slice_offset + TIME_SLICED_ITEMS;

      CTA_SYNC();

#pragma unroll
      for (int i = 0; i < ITEMS_PER_THREAD; i++)
      {
        // Window bound matches the slice stride (TIME_SLICED_ITEMS == WARP_TIME_SLICED_ITEMS when time slicing).
        int item_offset = ranks[i] - slice_offset;
        if (item_offset >= 0 && item_offset < TIME_SLICED_ITEMS)
        {
          _CCCL_IF_CONSTEXPR (INSERT_PADDING)
          {
            item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
          }
          detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
        }
      }

      CTA_SYNC();

#pragma unroll
      for (int i = 0; i < ITEMS_PER_THREAD; i++)
      {
        // Read a strip of items
        const int strip_offset = i * BLOCK_THREADS;
        const int strip_oob    = strip_offset + BLOCK_THREADS;

        if (slice_offset < strip_oob && slice_oob > strip_offset)
        {
          int item_offset = strip_offset + linear_tid - slice_offset;
          if (item_offset >= 0 && item_offset < TIME_SLICED_ITEMS)
          {
            _CCCL_IF_CONSTEXPR (INSERT_PADDING)
            {
              item_offset += item_offset >> LOG_SMEM_BANKS;
            }
            temp_items[i] = temp_storage.buff[item_offset];
          }
        }
      }
    }

// Copy
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      output_items[i] = temp_items[i];
    }
  }

public:
  //! @brief Collective constructor that uses a `__shared__` buffer declared by this class as temporary storage.
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockExchange()
      : temp_storage(PrivateStorage())
  {}

  //! @brief Collective constructor that uses the caller-provided storage.
  //! @param[in] temp_storage Storage of layout type TempStorage
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockExchange(TempStorage& temp_storage)
      : temp_storage(temp_storage.Alias())
  {}

  //! @brief Transposes data items from a striped arrangement to a blocked arrangement.
  //! @param[in]  input_items  Items in striped arrangement
  //! @param[out] output_items Items in blocked arrangement (may alias `input_items`)
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  StripedToBlocked(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])
  {
    StripedToBlocked(input_items, output_items, Int2Type<WARP_TIME_SLICING>());
  }

  //! @brief Transposes data items from a blocked arrangement to a striped arrangement.
  //! @param[in]  input_items  Items in blocked arrangement
  //! @param[out] output_items Items in striped arrangement (may alias `input_items`)
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  BlockedToStriped(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])
  {
    BlockedToStriped(input_items, output_items, Int2Type<WARP_TIME_SLICING>());
  }

  //! @brief Transposes data items from a warp-striped arrangement to a blocked arrangement.
  //! @param[in]  input_items  Items in warp-striped arrangement
  //! @param[out] output_items Items in blocked arrangement (may alias `input_items`)
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  WarpStripedToBlocked(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])
  {
    WarpStripedToBlocked(input_items, output_items, Int2Type<WARP_TIME_SLICING>());
  }

  //! @brief Transposes data items from a blocked arrangement to a warp-striped arrangement.
  //! @param[in]  input_items  Items in blocked arrangement
  //! @param[out] output_items Items in warp-striped arrangement (may alias `input_items`)
  template <typename OutputT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  BlockedToWarpStriped(const T (&input_items)[ITEMS_PER_THREAD], OutputT (&output_items)[ITEMS_PER_THREAD])
  {
    BlockedToWarpStriped(input_items, output_items, Int2Type<WARP_TIME_SLICING>());
  }

  //! @brief Exchanges data items annotated by rank into a blocked arrangement.
  //! @param[in]  input_items  Items to exchange
  //! @param[out] output_items Items in blocked arrangement
  //! @param[in]  ranks        Destination tile position of each item (presumably in [0, TILE_ITEMS))
  template <typename OutputT, typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToBlocked(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD])
  {
    ScatterToBlocked(input_items, output_items, ranks, Int2Type<WARP_TIME_SLICING>());
  }

  //! @brief Exchanges data items annotated by rank into a striped arrangement.
  //! @param[in]  input_items  Items to exchange
  //! @param[out] output_items Items in striped arrangement
  //! @param[in]  ranks        Destination tile position of each item (presumably in [0, TILE_ITEMS))
  template <typename OutputT, typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToStriped(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD])
  {
    ScatterToStriped(input_items, output_items, ranks, Int2Type<WARP_TIME_SLICING>());
  }

  //! @brief Exchanges data items annotated by rank into a striped arrangement, skipping items whose rank is negative.
  //!
  //! Striped slots that no thread wrote are still read back; they receive whatever the shared buffer held.
  //! NOTE(review): this path does not time-slice. With WARP_TIME_SLICING and more than one warp, the buffer only
  //! holds TIME_SLICED_ITEMS (+ padding) items while the read-back spans the whole tile — confirm callers never
  //! combine the two.
  //!
  //! @param[in]  input_items  Items to exchange
  //! @param[out] output_items Items in striped arrangement
  //! @param[in]  ranks        Destination tile position of each item; a negative rank drops the item
  template <typename OutputT, typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToStripedGuarded(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD])
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      if (ranks[i] >= 0)
      {
        int item_offset = ranks[i];
        _CCCL_IF_CONSTEXPR (INSERT_PADDING)
        {
          item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
        }
        // The buffer is raw `Uninitialized` storage: construct into it like every other write path in this class.
        detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
      }
    }

    CTA_SYNC();

#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = i * BLOCK_THREADS + linear_tid;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
      }
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

  //! @brief Exchanges valid data items annotated by rank into a striped arrangement.
  //!
  //! Items whose `is_valid` flag is false are not stored. Striped slots that no thread wrote are still read back and
  //! receive whatever the shared buffer held.
  //! NOTE(review): like ScatterToStripedGuarded, this path does not time-slice — confirm it is not used with
  //! WARP_TIME_SLICING on multi-warp blocks.
  //!
  //! @param[in]  input_items  Items to exchange
  //! @param[out] output_items Items in striped arrangement
  //! @param[in]  ranks        Destination tile position of each item
  //! @param[in]  is_valid     Per-item flag; only flagged items are stored
  template <typename OutputT, typename OffsetT, typename ValidFlag>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToStripedFlagged(
    const T (&input_items)[ITEMS_PER_THREAD],
    OutputT (&output_items)[ITEMS_PER_THREAD],
    OffsetT (&ranks)[ITEMS_PER_THREAD],
    ValidFlag (&is_valid)[ITEMS_PER_THREAD])
  {
#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = ranks[i];
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
      }
      if (is_valid[i])
      {
        // The buffer is raw `Uninitialized` storage: construct into it like every other write path in this class.
        detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
      }
    }

    CTA_SYNC();

#pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; i++)
    {
      int item_offset = i * BLOCK_THREADS + linear_tid;
      _CCCL_IF_CONSTEXPR (INSERT_PADDING)
      {
        item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
      }
      output_items[i] = temp_storage.buff[item_offset];
    }
  }

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document

  // In-place convenience overloads: input and output are the same array.

  _CCCL_DEVICE _CCCL_FORCEINLINE void StripedToBlocked(T (&items)[ITEMS_PER_THREAD])
  {
    StripedToBlocked(items, items);
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE void BlockedToStriped(T (&items)[ITEMS_PER_THREAD])
  {
    BlockedToStriped(items, items);
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE void WarpStripedToBlocked(T (&items)[ITEMS_PER_THREAD])
  {
    WarpStripedToBlocked(items, items);
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE void BlockedToWarpStriped(T (&items)[ITEMS_PER_THREAD])
  {
    BlockedToWarpStriped(items, items);
  }

  template <typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToBlocked(T (&items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD])
  {
    ScatterToBlocked(items, items, ranks);
  }

  template <typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToStriped(T (&items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD])
  {
    ScatterToStriped(items, items, ranks);
  }

  template <typename OffsetT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  ScatterToStripedGuarded(T (&items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD])
  {
    ScatterToStripedGuarded(items, items, ranks);
  }

  template <typename OffsetT, typename ValidFlag>
  _CCCL_DEVICE _CCCL_FORCEINLINE void ScatterToStripedFlagged(
    T (&items)[ITEMS_PER_THREAD], OffsetT (&ranks)[ITEMS_PER_THREAD], ValidFlag (&is_valid)[ITEMS_PER_THREAD])
  {
    // Forward to the flagged overload. The 4-argument ScatterToStriped overloads take an Int2Type time-slicing tag,
    // not a flag array, so forwarding there would fail to compile on instantiation.
    ScatterToStripedFlagged(items, items, ranks, is_valid);
  }

#endif // _CCCL_DOXYGEN_INVOKED
};

CUB_NAMESPACE_END