/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
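
//! @file
//! Block-wide radix ranking: collective operations for ranking unsigned integer keys within a CUDA
//! thread block.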

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/block/block_scan.cuh>
#include <cub/block/radix_rank_sort_operations.cuh>
#include <cub/thread/thread_reduce.cuh>
#include <cub/thread/thread_scan.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/cstdint>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

CUB_NAMESPACE_BEGIN

//! Radix ranking algorithm: selects which block-wide ranking implementation to use
enum RadixRankAlgorithm
{
  //! Ranking using the BlockRadixRank algorithm with MEMOIZE_OUTER_SCAN == false. Keys are counted
  //! into shared-memory digit counters and ranked via a block-wide prefix scan. Requires a blocked
  //! arrangement of keys; does not support count callbacks.
  RADIX_RANK_BASIC,

  //! Ranking using the BlockRadixRank algorithm with MEMOIZE_OUTER_SCAN == true. Like
  //! RADIX_RANK_BASIC, but buffers the raking segment in registers to reduce shared-memory traffic.
  RADIX_RANK_MEMOIZE,

  //! Ranking using the BlockRadixRankMatch algorithm, which ranks keys within each warp via
  //! match-any over warp-private digit counters. Uses less shared memory than RADIX_RANK_BASIC,
  //! requires a warp-striped arrangement of keys, and supports count callbacks.
  RADIX_RANK_MATCH,

  //! Ranking using the BlockRadixRankMatchEarlyCounts algorithm with WARP_MATCH_ANY: match-based
  //! ranking that additionally computes per-digit counts early.
  RADIX_RANK_MATCH_EARLY_COUNTS_ANY,

  //! Ranking using the BlockRadixRankMatchEarlyCounts algorithm with WARP_MATCH_ATOMIC_OR: as
  //! above, but emulates warp matching with shared-memory atomic OR.
  RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR
};

//! Default (no-op) counts callback, used when the caller does not need per-digit counts
template <int BINS_PER_THREAD>
struct BlockRadixRankEmptyCallback
{
  _CCCL_DEVICE _CCCL_FORCEINLINE void operator()(int (&bins)[BINS_PER_THREAD]) {}
};
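
// A "counts callback" is any functor invocable with the per-thread array of digit counts produced
// during ranking (one count per bin tracked by the calling thread). As a hypothetical sketch (the
// name, the histogram layout, and the assumption that every tracked bin is valid are illustrative,
// not part of this header), a callback publishing those counts to a per-block histogram could be:
//
//   struct StoreDigitCountsCallback
//   {
//     int* d_block_histogram; // one counter per radix digit for this thread block
//
//     template <int BINS_PER_THREAD>
//     __device__ void operator()(int (&bins)[BINS_PER_THREAD])
//     {
//       for (int track = 0; track < BINS_PER_THREAD; ++track)
//       {
//         d_block_histogram[threadIdx.x * BINS_PER_THREAD + track] = bins[track];
//       }
//     }
//   };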

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
namespace detail
{

// Dispatches MatchAny to the partial-warp variant for the last (incomplete) warp of the thread
// block, and to the full-warp variant for all other warps.
template <int Bits, int PartialWarpThreads, int PartialWarpId>
struct warp_in_block_matcher_t
{
  static _CCCL_DEVICE ::cuda::std::uint32_t match_any(::cuda::std::uint32_t label, ::cuda::std::uint32_t warp_id)
  {
    if (warp_id == static_cast<::cuda::std::uint32_t>(PartialWarpId))
    {
      return MatchAny<Bits, PartialWarpThreads>(label);
    }

    return MatchAny<Bits>(label);
  }
};

// Specialization for thread blocks whose size is a multiple of the warp size: every warp is full,
// so the full-warp match can always be used.
template <int Bits, int PartialWarpId>
struct warp_in_block_matcher_t<Bits, 0, PartialWarpId>
{
  static _CCCL_DEVICE ::cuda::std::uint32_t match_any(::cuda::std::uint32_t label, ::cuda::std::uint32_t warp_id)
  {
    return MatchAny<Bits>(label);
  }
};

} // namespace detail
#endif // _CCCL_DOXYGEN_INVOKED
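
//! BlockRadixRank provides operations for ranking unsigned integer keys within a CUDA thread block.
//! Each key's digit is counted into packed shared-memory counters, the counters are prefix-summed
//! with a block-wide scan, and each key then receives the count of keys that precede it in the
//! stably digit-ordered tile.
//!
//! A minimal usage sketch (illustrative; the kernel name and the elided data movement are
//! placeholders, and cub::BFEDigitExtractor from the included radix_rank_sort_operations.cuh is
//! assumed for digit extraction): a 1-D block of 128 threads ranks 4 keys per thread by their low
//! 5 bits.
//!
//! @code
//! #include <cub/block/block_radix_rank.cuh>
//!
//! __global__ void ExampleKernel(unsigned int* d_keys)
//! {
//!   // Specialize BlockRadixRank for a 1-D block of 128 threads, 5-bit digits, ascending order
//!   using BlockRadixRankT = cub::BlockRadixRank<128, 5, false>;
//!
//!   // Allocate shared memory for the collective
//!   __shared__ typename BlockRadixRankT::TempStorage temp_storage;
//!
//!   // Obtain this thread's blocked segment of 4 consecutive keys
//!   unsigned int keys[4];
//!   // ... load keys ...
//!
//!   // Rank keys by the digit in bits [0, 5)
//!   int ranks[4];
//!   cub::BFEDigitExtractor<unsigned int> extractor(0, 5);
//!   BlockRadixRankT(temp_storage).RankKeys(keys, ranks, extractor);
//!
//!   // ranks[i] is now the position of keys[i] within the 512-key tile when stably
//!   // ordered by that digit (e.g., usable as a scatter offset into shared memory)
//! }
//! @endcode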

template <int BLOCK_DIM_X,
          int RADIX_BITS,
          bool IS_DESCENDING,
          bool MEMOIZE_OUTER_SCAN                 = true,
          BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
          cudaSharedMemConfig SMEM_CONFIG         = cudaSharedMemBankSizeFourByte,
          int BLOCK_DIM_Y                         = 1,
          int BLOCK_DIM_Z                         = 1,
          int LEGACY_PTX_ARCH                     = 0>
class BlockRadixRank
{
private:
  // Integer type for digit counters (to be packed into words of type PackedCounters)
  using DigitCounter = unsigned short;

  // Integer type for packing DigitCounters into columns of shared memory banks
  using PackedCounter =
    ::cuda::std::_If<SMEM_CONFIG == cudaSharedMemBankSizeEightByte, unsigned long long, unsigned int>;

  static constexpr DigitCounter max_tile_size = ::cuda::std::numeric_limits<DigitCounter>::max();

  enum
  {
    // The thread block size in threads
    BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

    RADIX_DIGITS = 1 << RADIX_BITS,

    LOG_WARP_THREADS = CUB_LOG_WARP_THREADS(0),
    WARP_THREADS     = 1 << LOG_WARP_THREADS,
    WARPS            = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

    BYTES_PER_COUNTER     = sizeof(DigitCounter),
    LOG_BYTES_PER_COUNTER = Log2<BYTES_PER_COUNTER>::VALUE,

    PACKING_RATIO     = static_cast<int>(sizeof(PackedCounter) / sizeof(DigitCounter)),
    LOG_PACKING_RATIO = Log2<PACKING_RATIO>::VALUE,

    // Always at least one lane
    LOG_COUNTER_LANES = CUB_MAX((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0),
    COUNTER_LANES     = 1 << LOG_COUNTER_LANES,

    // The number of packed counters per thread (plus one for padding)
    PADDED_COUNTER_LANES = COUNTER_LANES + 1,
    RAKING_SEGMENT       = PADDED_COUNTER_LANES,
  };
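
  // Worked example of the packing arithmetic above (illustrative): with RADIX_BITS = 5 and the
  // default four-byte bank configuration, each 2-byte DigitCounter is packed two per 4-byte
  // PackedCounter (PACKING_RATIO = 2, LOG_PACKING_RATIO = 1), so the 32 radix digits map onto
  // COUNTER_LANES = 2^(5-1) = 16 lanes of packed counters per thread, padded by one lane
  // (PADDED_COUNTER_LANES = RAKING_SEGMENT = 17) to stagger raking accesses across shared-memory banks.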

public:
  enum
  {
    BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
  };

private:
  using BlockScan = BlockScan<PackedCounter, BLOCK_DIM_X, INNER_SCAN_ALGORITHM, BLOCK_DIM_Y, BLOCK_DIM_Z>;

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  struct __align__(16) _TempStorage
  {
    union Aliasable
    {
      DigitCounter digit_counters[PADDED_COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
      PackedCounter raking_grid[BLOCK_THREADS][RAKING_SEGMENT];

    } aliasable;

    // Storage for scanning local ranks
    typename BlockScan::TempStorage block_scan;
  };
#endif // !_CCCL_DOXYGEN_INVOKED

  _TempStorage& temp_storage;

  unsigned int linear_tid;

  PackedCounter cached_segment[RAKING_SEGMENT];

  _CCCL_DEVICE _CCCL_FORCEINLINE _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage private_storage;
    return private_storage;
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE PackedCounter Upsweep()
  {
    PackedCounter* smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];
    PackedCounter* raking_ptr;

    if (MEMOIZE_OUTER_SCAN)
    {
// Copy data into registers
#pragma unroll
      for (int i = 0; i < RAKING_SEGMENT; i++)
      {
        cached_segment[i] = smem_raking_ptr[i];
      }
      raking_ptr = cached_segment;
    }
    else
    {
      raking_ptr = smem_raking_ptr;
    }

    return cub::internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, ::cuda::std::plus<>{});
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE void ExclusiveDownsweep(PackedCounter raking_partial)
  {
    PackedCounter* smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];

    PackedCounter* raking_ptr = (MEMOIZE_OUTER_SCAN) ? cached_segment : smem_raking_ptr;

    // Exclusive raking downsweep scan
    internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, ::cuda::std::plus<>{}, raking_partial);

    if (MEMOIZE_OUTER_SCAN)
    {
// Copy data back to smem
#pragma unroll
      for (int i = 0; i < RAKING_SEGMENT; i++)
      {
        smem_raking_ptr[i] = cached_segment[i];
      }
    }
  }

  _CCCL_DEVICE _CCCL_FORCEINLINE void ResetCounters()
  {
// Reset shared memory digit counters
#pragma unroll
    for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
    {
      *((PackedCounter*) temp_storage.aliasable.digit_counters[LANE][linear_tid]) = 0;
    }
  }

  struct PrefixCallBack
  {
    _CCCL_DEVICE _CCCL_FORCEINLINE PackedCounter operator()(PackedCounter block_aggregate)
    {
      PackedCounter block_prefix = 0;

// Propagate totals in packed fields
#pragma unroll
      for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
      {
        block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED);
      }

      return block_prefix;
    }
  };
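
  // Note on PrefixCallBack (illustrative): ExclusiveSum scans each sub-counter field of the packed
  // words independently, so the callback seeds every field with the block-wide totals of all
  // lower-order fields. For example, with PACKING_RATIO = 2 the returned prefix is
  // block_aggregate << 16, i.e. the upper counter field of every word is offset by the total count
  // accumulated in the lower field, yielding one contiguous exclusive prefix over all digits.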

  _CCCL_DEVICE _CCCL_FORCEINLINE void ScanCounters()
  {
    // Upsweep scan
    PackedCounter raking_partial = Upsweep();

    // Compute exclusive sum
    PackedCounter exclusive_partial;
    PrefixCallBack prefix_call_back;
    BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);

    // Downsweep scan with exclusive partial
    ExclusiveDownsweep(exclusive_partial);
  }

public:
  //! Temporary storage for BlockRadixRank; allocate in shared memory (or alias an existing allocation)
  struct TempStorage : Uninitialized<_TempStorage>
  {};

  //! Collective constructor using a private static allocation of shared memory as temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockRadixRank()
      : temp_storage(PrivateStorage())
      , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
  {}

  //! Collective constructor using the specified memory allocation as temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockRadixRank(TempStorage& temp_storage)
      : temp_storage(temp_storage.Alias())
      , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
  {}

  //! Rank keys: for each key, compute its position within the tile when the tile is stably ordered
  //! by the digit produced by digit_extractor.
  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
  {
    static_assert(BLOCK_THREADS * KEYS_PER_THREAD <= max_tile_size,
                  "DigitCounter type is too small to hold this number of keys");

    DigitCounter thread_prefixes[KEYS_PER_THREAD]; // For each key, the count of previous keys in this tile having the
                                                   // same digit
    DigitCounter* digit_counters[KEYS_PER_THREAD]; // For each key, the byte-offset of its corresponding digit counter
                                                   // in smem

    // Reset shared memory digit counters
    ResetCounters();

#pragma unroll
    for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
    {
      // Get digit
      ::cuda::std::uint32_t digit = digit_extractor.Digit(keys[ITEM]);

      // Get sub-counter
      ::cuda::std::uint32_t sub_counter = digit >> LOG_COUNTER_LANES;

      // Get counter lane
      ::cuda::std::uint32_t counter_lane = digit & (COUNTER_LANES - 1);

      if (IS_DESCENDING)
      {
        sub_counter  = PACKING_RATIO - 1 - sub_counter;
        counter_lane = COUNTER_LANES - 1 - counter_lane;
      }

      // Pointer to smem digit counter
      digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane][linear_tid][sub_counter];

      // Load thread-exclusive prefix
      thread_prefixes[ITEM] = *digit_counters[ITEM];

      // Store inclusive prefix
      *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
    }

    CTA_SYNC();

    // Scan shared memory counters
    ScanCounters();

    CTA_SYNC();

// Extract the local ranks of each key
#pragma unroll
    for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
    {
      // Add in thread block exclusive prefix
      ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
    }
  }

  //! Rank keys and, in addition, return the tile-wide exclusive digit prefix (the number of keys
  //! whose digits order before that bin) for each bin tracked by the calling thread, i.e. bins
  //! (linear_tid * BINS_TRACKED_PER_THREAD + track) for track in [0, BINS_TRACKED_PER_THREAD).
  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD],
           int (&ranks)[KEYS_PER_THREAD],
           DigitExtractorT digit_extractor,
           int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
  {
    static_assert(BLOCK_THREADS * KEYS_PER_THREAD <= max_tile_size,
                  "DigitCounter type is too small to hold this number of keys");

    // Rank keys
    RankKeys(keys, ranks, digit_extractor);

// Get the exclusive digit totals corresponding to the bins tracked by the calling thread.
#pragma unroll
    for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
    {
      int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

      if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
      {
        if (IS_DESCENDING)
        {
          bin_idx = RADIX_DIGITS - bin_idx - 1;
        }

        // Obtain ex/inclusive digit counts.  (Unfortunately these all reside in the
        // first counter column, resulting in unavoidable bank conflicts.)
        unsigned int counter_lane = (bin_idx & (COUNTER_LANES - 1));
        unsigned int sub_counter  = bin_idx >> (LOG_COUNTER_LANES);

        exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane][0][sub_counter];
      }
    }
  }

};
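
//! BlockRadixRankMatch ranks keys using warp-synchronous match-any over warp-private digit
//! counters, which lowers the shared-memory footprint relative to BlockRadixRank and enables a
//! counts callback reporting per-digit totals for the tile.
//!
//! A minimal usage sketch (illustrative; the kernel name and data movement are placeholders, keys
//! are assumed to be arranged as required, e.g. warp-striped as noted for RADIX_RANK_MATCH above,
//! and cub::BFEDigitExtractor is assumed as in the previous example):
//!
//! @code
//! __global__ void ExampleKernel(unsigned int* d_keys)
//! {
//!   using BlockRankT = cub::BlockRadixRankMatch<128, 5, false>;
//!   __shared__ typename BlockRankT::TempStorage temp_storage;
//!
//!   unsigned int keys[4];
//!   // ... load this thread's keys ...
//!
//!   // Rank keys by their low 5 bits and also obtain, per tracked bin, the exclusive
//!   // prefix of keys with smaller digit values
//!   int ranks[4];
//!   int exclusive_digit_prefix[BlockRankT::BINS_TRACKED_PER_THREAD];
//!   cub::BFEDigitExtractor<unsigned int> extractor(0, 5);
//!   BlockRankT(temp_storage).RankKeys(keys, ranks, extractor, exclusive_digit_prefix);
//! }
//! @endcode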

template <int BLOCK_DIM_X,
          int RADIX_BITS,
          bool IS_DESCENDING,
          BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
          int BLOCK_DIM_Y                         = 1,
          int BLOCK_DIM_Z                         = 1,
          int LEGACY_PTX_ARCH                     = 0>
class BlockRadixRankMatch
{
private:
  using RankT         = int32_t;
  using DigitCounterT = int32_t;

  enum
  {
    // The thread block size in threads
    BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

    RADIX_DIGITS = 1 << RADIX_BITS,

    LOG_WARP_THREADS     = CUB_LOG_WARP_THREADS(0),
    WARP_THREADS         = 1 << LOG_WARP_THREADS,
    PARTIAL_WARP_THREADS = BLOCK_THREADS % WARP_THREADS,
    WARPS                = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

    PADDED_WARPS = ((WARPS & 0x1) == 0) ? WARPS + 1 : WARPS,

    COUNTERS              = PADDED_WARPS * RADIX_DIGITS,
    RAKING_SEGMENT        = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
    PADDED_RAKING_SEGMENT = ((RAKING_SEGMENT & 0x1) == 0) ? RAKING_SEGMENT + 1 : RAKING_SEGMENT,
  };

public:
  enum
  {
    BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
  };

private:
  using BlockScanT = BlockScan<DigitCounterT, BLOCK_THREADS, INNER_SCAN_ALGORITHM, BLOCK_DIM_Y, BLOCK_DIM_Z>;

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
  struct __align__(16) _TempStorage
  {
    typename BlockScanT::TempStorage block_scan;

    union __align__(16) Aliasable
    {
      volatile DigitCounterT warp_digit_counters[RADIX_DIGITS][PADDED_WARPS];
      DigitCounterT raking_grid[BLOCK_THREADS][PADDED_RAKING_SEGMENT];
    }
    aliasable;
  };
#endif // !_CCCL_DOXYGEN_INVOKED

  _TempStorage& temp_storage;

  unsigned int linear_tid;

public:
  //! Temporary storage for BlockRadixRankMatch; allocate in shared memory (or alias an existing allocation)
  struct TempStorage : Uninitialized<_TempStorage>
  {};

  //! Collective constructor using the specified memory allocation as temporary storage
  _CCCL_DEVICE _CCCL_FORCEINLINE BlockRadixRankMatch(TempStorage& temp_storage)
      : temp_storage(temp_storage.Alias())
      , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
  {}

  //! Computes the tile-wide count for each digit tracked by the calling thread (by differencing
  //! the scanned warp counters) and reports the counts through the supplied callback.
  template <int KEYS_PER_THREAD, typename CountsCallback>
  _CCCL_DEVICE _CCCL_FORCEINLINE void CallBack(CountsCallback callback)
  {
    int bins[BINS_TRACKED_PER_THREAD];
// Get count for each digit
#pragma unroll
    for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
    {
      int bin_idx              = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
      constexpr int TILE_ITEMS = KEYS_PER_THREAD * BLOCK_THREADS;

      if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
      {
        if (IS_DESCENDING)
        {
          bin_idx     = RADIX_DIGITS - bin_idx - 1;
          bins[track] = (bin_idx > 0 ? temp_storage.aliasable.warp_digit_counters[bin_idx - 1][0] : TILE_ITEMS)
                      - temp_storage.aliasable.warp_digit_counters[bin_idx][0];
        }
        else
        {
          bins[track] =
            (bin_idx < RADIX_DIGITS - 1 ? temp_storage.aliasable.warp_digit_counters[bin_idx + 1][0] : TILE_ITEMS)
            - temp_storage.aliasable.warp_digit_counters[bin_idx][0];
        }
      }
    }
    callback(bins);
  }

  //! Rank keys and report the tile-wide per-digit counts through the supplied counts callback.
  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT, typename CountsCallback>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD],
           int (&ranks)[KEYS_PER_THREAD],
           DigitExtractorT digit_extractor,
           CountsCallback callback)
  {
    // Initialize shared digit counters

#pragma unroll
    for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
    {
      temp_storage.aliasable.raking_grid[linear_tid][ITEM] = 0;
    }

    CTA_SYNC();

    // Each warp will strip-mine its section of input, one strip at a time

    volatile DigitCounterT* digit_counters[KEYS_PER_THREAD];
    uint32_t warp_id      = linear_tid >> LOG_WARP_THREADS;
    uint32_t lane_mask_lt = LaneMaskLt();

#pragma unroll
    for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
    {
      // My digit
      ::cuda::std::uint32_t digit = digit_extractor.Digit(keys[ITEM]);

      if (IS_DESCENDING)
      {
        digit = RADIX_DIGITS - digit - 1;
      }

      // Mask of peers who have same digit as me
      uint32_t peer_mask =
        detail::warp_in_block_matcher_t<RADIX_BITS, PARTIAL_WARP_THREADS, WARPS - 1>::match_any(digit, warp_id);

      // Pointer to smem digit counter for this key
      digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id];

      // Number of occurrences in previous strips
      DigitCounterT warp_digit_prefix = *digit_counters[ITEM];

      // Warp-sync
      WARP_SYNC(0xFFFFFFFF);

      // Number of peers having same digit as me
      int32_t digit_count = __popc(peer_mask);

      // Number of lower-ranked peers having same digit seen so far
      int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);

      if (peer_digit_prefix == 0)
      {
        // First thread for each digit updates the shared warp counter
        *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
      }

      // Warp-sync
      WARP_SYNC(0xFFFFFFFF);

      // Number of prior keys having same digit
      ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
    }

    CTA_SYNC();

    // Scan warp counters

    DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];

#pragma unroll
    for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
    {
      scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid][ITEM];
    }

    BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);

#pragma unroll
    for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
    {
      temp_storage.aliasable.raking_grid[linear_tid][ITEM] = scan_counters[ITEM];
    }

    CTA_SYNC();
    if (!::cuda::std::is_same<CountsCallback, BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>>::value)
    {
      CallBack<KEYS_PER_THREAD>(callback);
    }

// Seed ranks with counter values from previous warps
#pragma unroll
    for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
    {
      ranks[ITEM] += *digit_counters[ITEM];
    }
  }

  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
  {
    RankKeys(keys, ranks, digit_extractor, BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
  }

  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT, typename CountsCallback>
  _CCCL_DEVICE _CCCL_FORCEINLINE void RankKeys(
    UnsignedBits (&keys)[KEYS_PER_THREAD],
    int (&ranks)[KEYS_PER_THREAD],
    DigitExtractorT digit_extractor,
    int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD],
    CountsCallback callback)
  {
    RankKeys(keys, ranks, digit_extractor, callback);

// Get exclusive count for each digit
#pragma unroll
    for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
    {
      int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

      if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
      {
        if (IS_DESCENDING)
        {
          bin_idx = RADIX_DIGITS - bin_idx - 1;
        }

        exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx][0];
      }
    }
  }

  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD],
           int (&ranks)[KEYS_PER_THREAD],
           DigitExtractorT digit_extractor,
           int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
  {
    RankKeys(
      keys, ranks, digit_extractor, exclusive_digit_prefix, BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
  }

};

//! Warp-matching strategy used by match-based radix ranking
enum WarpMatchAlgorithm
{
  //! Match lanes holding the same digit with a warp-wide match-any operation
  WARP_MATCH_ANY,

  //! Emulate matching by atomically OR-ing per-lane masks into shared memory
  WARP_MATCH_ATOMIC_OR
};
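
//! BlockRadixRankMatchEarlyCounts is a match-based ranking that computes per-digit counts early
//! (before the ranks themselves), at the expense of additional work; this can be useful when the
//! digit counts are needed up front, e.g. for decoupled look-back. MATCH_ALGORITHM selects between
//! warp match-any and an atomic-OR emulation of matching.
//!
//! A minimal usage sketch (illustrative; names and data movement are placeholders, as in the
//! examples above):
//!
//! @code
//! __global__ void ExampleKernel(unsigned int* d_keys)
//! {
//!   using BlockRankT = cub::BlockRadixRankMatchEarlyCounts<128, 5, false>;
//!   __shared__ typename BlockRankT::TempStorage temp_storage;
//!
//!   unsigned int keys[4];
//!   // ... load this thread's keys ...
//!
//!   int ranks[4];
//!   int exclusive_digit_prefix[BlockRankT::BINS_PER_THREAD];
//!   cub::BFEDigitExtractor<unsigned int> extractor(0, 5);
//!   BlockRankT(temp_storage).RankKeys(keys, ranks, extractor, exclusive_digit_prefix);
//! }
//! @endcode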

template <int BLOCK_DIM_X,
          int RADIX_BITS,
          bool IS_DESCENDING,
          BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
          WarpMatchAlgorithm MATCH_ALGORITHM      = WARP_MATCH_ANY,
          int NUM_PARTS                           = 1>
struct BlockRadixRankMatchEarlyCounts
{
  // constants
  enum
  {
    BLOCK_THREADS           = BLOCK_DIM_X,
    RADIX_DIGITS            = 1 << RADIX_BITS,
    BINS_PER_THREAD         = (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS,
    BINS_TRACKED_PER_THREAD = BINS_PER_THREAD,
    FULL_BINS               = BINS_PER_THREAD * BLOCK_THREADS == RADIX_DIGITS,
    WARP_THREADS            = CUB_PTX_WARP_THREADS,
    PARTIAL_WARP_THREADS    = BLOCK_THREADS % WARP_THREADS,
    BLOCK_WARPS             = BLOCK_THREADS / WARP_THREADS,
    PARTIAL_WARP_ID         = BLOCK_WARPS - 1,
    WARP_MASK               = ~0,
    NUM_MATCH_MASKS         = MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR ? BLOCK_WARPS : 0,
    // Guard against declaring zero-sized array:
    MATCH_MASKS_ALLOC_SIZE = NUM_MATCH_MASKS < 1 ? 1 : NUM_MATCH_MASKS,
  };

  // types
  using BlockScan = cub::BlockScan<int, BLOCK_THREADS, INNER_SCAN_ALGORITHM>;

  struct TempStorage
  {
    union
    {
      int warp_offsets[BLOCK_WARPS][RADIX_DIGITS];
      int warp_histograms[BLOCK_WARPS][RADIX_DIGITS][NUM_PARTS];
    };

    int match_masks[MATCH_MASKS_ALLOC_SIZE][RADIX_DIGITS];

    typename BlockScan::TempStorage prefix_tmp;
  };

  TempStorage& temp_storage;

  // internal ranking implementation
  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT, typename CountsCallback>
  struct BlockRadixRankMatchInternal
  {
    TempStorage& s;
    DigitExtractorT digit_extractor;
    CountsCallback callback;
    int warp;
    int lane;

    _CCCL_DEVICE _CCCL_FORCEINLINE ::cuda::std::uint32_t Digit(UnsignedBits key)
    {
      ::cuda::std::uint32_t digit = digit_extractor.Digit(key);
      return IS_DESCENDING ? RADIX_DIGITS - 1 - digit : digit;
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE int ThreadBin(int u)
    {
      int bin = threadIdx.x * BINS_PER_THREAD + u;
      return IS_DESCENDING ? RADIX_DIGITS - 1 - bin : bin;
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void ComputeHistogramsWarp(UnsignedBits (&keys)[KEYS_PER_THREAD])
    {
      // int* warp_offsets = &s.warp_offsets[warp][0];
      int(&warp_histograms)[RADIX_DIGITS][NUM_PARTS] = s.warp_histograms[warp];
// compute warp-private histograms
#pragma unroll
      for (int bin = lane; bin < RADIX_DIGITS; bin += WARP_THREADS)
      {
#pragma unroll
        for (int part = 0; part < NUM_PARTS; ++part)
        {
          warp_histograms[bin][part] = 0;
        }
      }
      if (MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR)
      {
        int* match_masks = &s.match_masks[warp][0];
#pragma unroll
        for (int bin = lane; bin < RADIX_DIGITS; bin += WARP_THREADS)
        {
          match_masks[bin] = 0;
        }
      }
      WARP_SYNC(WARP_MASK);

      // compute private per-part histograms
      int part = lane % NUM_PARTS;
#pragma unroll
      for (int u = 0; u < KEYS_PER_THREAD; ++u)
      {
        atomicAdd(&warp_histograms[Digit(keys[u])][part], 1);
      }

      // sum different parts;
      // no extra work is necessary if NUM_PARTS == 1
      if (NUM_PARTS > 1)
      {
        WARP_SYNC(WARP_MASK);
        // TODO: handle RADIX_DIGITS % WARP_THREADS != 0 if it becomes necessary
        constexpr int WARP_BINS_PER_THREAD = RADIX_DIGITS / WARP_THREADS;
        int bins[WARP_BINS_PER_THREAD];
#pragma unroll
        for (int u = 0; u < WARP_BINS_PER_THREAD; ++u)
        {
          int bin = lane + u * WARP_THREADS;
          bins[u] = cub::ThreadReduce(warp_histograms[bin], ::cuda::std::plus<>{});
        }
        CTA_SYNC();

        // store the resulting histogram in shared memory
        int* warp_offsets = &s.warp_offsets[warp][0];
#pragma unroll
        for (int u = 0; u < WARP_BINS_PER_THREAD; ++u)
        {
          int bin           = lane + u * WARP_THREADS;
          warp_offsets[bin] = bins[u];
        }
      }
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void ComputeOffsetsWarpUpsweep(int (&bins)[BINS_PER_THREAD])
    {
// sum up warp-private histograms
#pragma unroll
      for (int u = 0; u < BINS_PER_THREAD; ++u)
      {
        bins[u] = 0;
        int bin = ThreadBin(u);
        if (FULL_BINS || (bin >= 0 && bin < RADIX_DIGITS))
        {
#pragma unroll
          for (int j_warp = 0; j_warp < BLOCK_WARPS; ++j_warp)
          {
            int warp_offset             = s.warp_offsets[j_warp][bin];
            s.warp_offsets[j_warp][bin] = bins[u];
            bins[u] += warp_offset;
          }
        }
      }
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void ComputeOffsetsWarpDownsweep(int (&offsets)[BINS_PER_THREAD])
    {
#pragma unroll
      for (int u = 0; u < BINS_PER_THREAD; ++u)
      {
        int bin = ThreadBin(u);
        if (FULL_BINS || (bin >= 0 && bin < RADIX_DIGITS))
        {
          int digit_offset = offsets[u];
#pragma unroll
          for (int j_warp = 0; j_warp < BLOCK_WARPS; ++j_warp)
          {
            s.warp_offsets[j_warp][bin] += digit_offset;
          }
        }
      }
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void ComputeRanksItem(
      UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], Int2Type<WARP_MATCH_ATOMIC_OR>)
    {
      // compute key ranks
      int lane_mask     = 1 << lane;
      int* warp_offsets = &s.warp_offsets[warp][0];
      int* match_masks  = &s.match_masks[warp][0];
#pragma unroll
      for (int u = 0; u < KEYS_PER_THREAD; ++u)
      {
        ::cuda::std::uint32_t bin = Digit(keys[u]);
        int* p_match_mask         = &match_masks[bin];
        atomicOr(p_match_mask, lane_mask);
        WARP_SYNC(WARP_MASK);
        int bin_mask    = *p_match_mask;
        int leader      = (WARP_THREADS - 1) - __clz(bin_mask);
        int warp_offset = 0;
        int popc        = __popc(bin_mask & LaneMaskLe());
        if (lane == leader)
        {
          // atomic is a bit faster
          warp_offset = atomicAdd(&warp_offsets[bin], popc);
        }
        warp_offset = SHFL_IDX_SYNC(warp_offset, leader, WARP_MASK);
        if (lane == leader)
        {
          *p_match_mask = 0;
        }
        WARP_SYNC(WARP_MASK);
        ranks[u] = warp_offset + popc - 1;
      }
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void
    ComputeRanksItem(UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], Int2Type<WARP_MATCH_ANY>)
    {
      // compute key ranks
      int* warp_offsets = &s.warp_offsets[warp][0];
#pragma unroll
      for (int u = 0; u < KEYS_PER_THREAD; ++u)
      {
        ::cuda::std::uint32_t bin = Digit(keys[u]);
        int bin_mask =
          detail::warp_in_block_matcher_t<RADIX_BITS, PARTIAL_WARP_THREADS, BLOCK_WARPS - 1>::match_any(bin, warp);
        int leader      = (WARP_THREADS - 1) - __clz(bin_mask);
        int warp_offset = 0;
        int popc        = __popc(bin_mask & LaneMaskLe());
        if (lane == leader)
        {
          // atomic is a bit faster
          warp_offset = atomicAdd(&warp_offsets[bin], popc);
        }
        warp_offset = SHFL_IDX_SYNC(warp_offset, leader, WARP_MASK);
        ranks[u]    = warp_offset + popc - 1;
      }
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE void
    RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD],
             int (&ranks)[KEYS_PER_THREAD],
             int (&exclusive_digit_prefix)[BINS_PER_THREAD])
    {
      ComputeHistogramsWarp(keys);

      CTA_SYNC();
      int bins[BINS_PER_THREAD];
      ComputeOffsetsWarpUpsweep(bins);
      callback(bins);

      BlockScan(s.prefix_tmp).ExclusiveSum(bins, exclusive_digit_prefix);

      ComputeOffsetsWarpDownsweep(exclusive_digit_prefix);
      CTA_SYNC();
      ComputeRanksItem(keys, ranks, Int2Type<MATCH_ALGORITHM>());
    }

    _CCCL_DEVICE _CCCL_FORCEINLINE
    BlockRadixRankMatchInternal(TempStorage& temp_storage, DigitExtractorT digit_extractor, CountsCallback callback)
        : s(temp_storage)
        , digit_extractor(digit_extractor)
        , callback(callback)
        , warp(threadIdx.x / WARP_THREADS)
        , lane(LaneId())
    {}
  };

  _CCCL_DEVICE _CCCL_FORCEINLINE BlockRadixRankMatchEarlyCounts(TempStorage& temp_storage)
      : temp_storage(temp_storage)
  {}

  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT, typename CountsCallback>
  _CCCL_DEVICE _CCCL_FORCEINLINE void RankKeys(
    UnsignedBits (&keys)[KEYS_PER_THREAD],
    int (&ranks)[KEYS_PER_THREAD],
    DigitExtractorT digit_extractor,
    int (&exclusive_digit_prefix)[BINS_PER_THREAD],
    CountsCallback callback)
  {
    BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback> internal(
      temp_storage, digit_extractor, callback);
    internal.RankKeys(keys, ranks, exclusive_digit_prefix);
  }

  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD],
           int (&ranks)[KEYS_PER_THREAD],
           DigitExtractorT digit_extractor,
           int (&exclusive_digit_prefix)[BINS_PER_THREAD])
  {
    using CountsCallback = BlockRadixRankEmptyCallback<BINS_PER_THREAD>;
    BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback> internal(
      temp_storage, digit_extractor, CountsCallback());
    internal.RankKeys(keys, ranks, exclusive_digit_prefix);
  }

  template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
  _CCCL_DEVICE _CCCL_FORCEINLINE void
  RankKeys(UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
  {
    int exclusive_digit_prefix[BINS_PER_THREAD];
    RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix);
  }
};

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
namespace detail
{

// `BlockRadixRank` doesn't conform to the typical CUB pattern: it does not expose the ranking
// algorithm as a template parameter, while the other ranking implementations do not accept the
// same template parameters (e.g., they lack multi-dimensional thread block specializations).
//
// TODO(senior-zero) for 3.0:
// - Put existing implementations into the detail namespace
// - Support multi-dimensional thread blocks in the rest of implementations
// - Repurpose BlockRadixRank as an entry name with the algorithm template parameter
template <RadixRankAlgorithm RankAlgorithm, int BlockDimX, int RadixBits, bool IsDescending, BlockScanAlgorithm ScanAlgorithm>
using block_radix_rank_t = ::cuda::std::_If<
  RankAlgorithm == RADIX_RANK_BASIC,
  BlockRadixRank<BlockDimX, RadixBits, IsDescending, false, ScanAlgorithm>,
  ::cuda::std::_If<
    RankAlgorithm == RADIX_RANK_MEMOIZE,
    BlockRadixRank<BlockDimX, RadixBits, IsDescending, true, ScanAlgorithm>,
    ::cuda::std::_If<
      RankAlgorithm == RADIX_RANK_MATCH,
      BlockRadixRankMatch<BlockDimX, RadixBits, IsDescending, ScanAlgorithm>,
      ::cuda::std::_If<
        RankAlgorithm == RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
        BlockRadixRankMatchEarlyCounts<BlockDimX, RadixBits, IsDescending, ScanAlgorithm, WARP_MATCH_ANY>,
        BlockRadixRankMatchEarlyCounts<BlockDimX, RadixBits, IsDescending, ScanAlgorithm, WARP_MATCH_ATOMIC_OR>>>>>;
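
// For example (illustrative), block_radix_rank_t<RADIX_RANK_MATCH, 128, 5, false, BLOCK_SCAN_WARP_SCANS>
// resolves to BlockRadixRankMatch<128, 5, false, BLOCK_SCAN_WARP_SCANS>; RADIX_RANK_BASIC and
// RADIX_RANK_MEMOIZE select BlockRadixRank without and with outer-scan memoization, and the two
// EARLY_COUNTS values select BlockRadixRankMatchEarlyCounts with WARP_MATCH_ANY and
// WARP_MATCH_ATOMIC_OR, respectively.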

} // namespace detail
#endif // _CCCL_DOXYGEN_INVOKED

CUB_NAMESPACE_END