/home/runner/work/cccl/cccl/cub/cub/util_ptx.cuh

File members: /home/runner/work/cccl/cccl/cub/cub/util_ptx.cuh

/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/util_debug.cuh>
#include <cub/util_type.cuh>

CUB_NAMESPACE_BEGIN

/******************************************************************************
 * PTX helper macros
 ******************************************************************************/

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

// Select the inline-asm pointer operand helpers from the host pointer model:
// the 64-bit variants are used when either _WIN64 or __LP64__ is defined.
#  if defined(_WIN64) || defined(__LP64__)
// Set to 1 when a 64-bit pointer model was detected above.
#    define __CUB_LP64__ 1
// 64-bit register modifier for inlined asm
#    define _CUB_ASM_PTR_      "l"
// PTX type suffix paired with the 64-bit register modifier.
#    define _CUB_ASM_PTR_SIZE_ "u64"
#  else
// Set to 0 when neither 64-bit macro is defined.
#    define __CUB_LP64__       0
// 32-bit register modifier for inlined asm
#    define _CUB_ASM_PTR_      "r"
// PTX type suffix paired with the 32-bit register modifier.
#    define _CUB_ASM_PTR_SIZE_ "u32"
#  endif

#endif // DOXYGEN_SHOULD_SKIP_THIS

/******************************************************************************
 * Inlined PTX intrinsics
 ******************************************************************************/

namespace detail
{
/**
 * @brief Logical left shift of a 32-bit value, issued as the PTX `shl.b32` instruction.
 *
 * The shift is performed by PTX rather than by the C++ `<<` operator, so the result for any
 * operand pair is whatever `shl.b32` yields.
 * NOTE(review): the PTX ISA documents over-wide shift amounts as clamped — confirm before
 * relying on it.
 *
 * @param val      Value to shift
 * @param num_bits Shift amount, in bits
 * @return `val` shifted left by `num_bits`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE uint32_t LogicShiftLeft(uint32_t val, uint32_t num_bits)
{
  uint32_t shifted = 0;
  asm("shl.b32 %0, %1, %2;"
      : "=r"(shifted)
      : "r"(val), "r"(num_bits));
  return shifted;
}

/**
 * @brief Logical right shift of a 32-bit value, issued as the PTX `shr.b32` instruction.
 *
 * The shift is performed by PTX rather than by the C++ `>>` operator, so the result for any
 * operand pair is whatever `shr.b32` yields.
 * NOTE(review): the PTX ISA documents over-wide shift amounts as clamped — confirm before
 * relying on it.
 *
 * @param val      Value to shift
 * @param num_bits Shift amount, in bits
 * @return `val` shifted right by `num_bits`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE uint32_t LogicShiftRight(uint32_t val, uint32_t num_bits)
{
  uint32_t shifted = 0;
  asm("shr.b32 %0, %1, %2;"
      : "=r"(shifted)
      : "r"(val), "r"(num_bits));
  return shifted;
}
} // namespace detail

/**
 * @brief Shift-right then add, fused into a single PTX `vshr.u32.u32.u32.clamp.add` instruction.
 *
 * @param x      Value to shift
 * @param shift  Shift amount, in bits
 * @param addend Value added to the shifted result
 * @return Result of the fused instruction, i.e. `(x >> shift) + addend` for in-range shifts
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int SHR_ADD(unsigned int x, unsigned int shift, unsigned int addend)
{
  unsigned int shifted_sum;
  asm("vshr.u32.u32.u32.clamp.add %0, %1, %2, %3;"
      : "=r"(shifted_sum)
      : "r"(x), "r"(shift), "r"(addend));
  return shifted_sum;
}

/**
 * @brief Shift-left then add, fused into a single PTX `vshl.u32.u32.u32.clamp.add` instruction.
 *
 * @param x      Value to shift
 * @param shift  Shift amount, in bits
 * @param addend Value added to the shifted result
 * @return Result of the fused instruction, i.e. `(x << shift) + addend` for in-range shifts
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int SHL_ADD(unsigned int x, unsigned int shift, unsigned int addend)
{
  unsigned int shifted_sum;
  asm("vshl.u32.u32.u32.clamp.add %0, %1, %2, %3;"
      : "=r"(shifted_sum)
      : "r"(x), "r"(shift), "r"(addend));
  return shifted_sum;
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

/**
 * @brief Bitfield extract for sources that are converted to a single 32-bit word.
 *
 * This is the generic overload selected by the `Int2Type<BYTE_LEN>` tag; the 8-byte (and, when
 * enabled, 16-byte) cases have dedicated overloads. The source is converted to `unsigned int`
 * and handed to the PTX `bfe.u32` instruction.
 *
 * @param source    Value to extract bits from
 * @param bit_start Bit offset of the field
 * @param num_bits  Width of the field, in bits
 * @return The extracted field
 */
template <typename UnsignedBits, int BYTE_LEN>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits, Int2Type<BYTE_LEN> /*byte_len*/)
{
  const unsigned int source_word = (unsigned int) source;

  unsigned int field;
  asm("bfe.u32 %0, %1, %2, %3;"
      : "=r"(field)
      : "r"(source_word), "r"(bit_start), "r"(num_bits));
  return field;
}

/**
 * @brief Bitfield extract for 8-byte sources (selected by the `Int2Type<8>` tag).
 *
 * Implemented with plain shift-and-mask arithmetic instead of inline PTX. The masked value is
 * narrowed to `unsigned int` on return.
 *
 * @param source    Value to extract bits from
 * @param bit_start Bit offset of the field
 * @param num_bits  Width of the field, in bits (the mask is built as `(1ull << num_bits) - 1`)
 * @return The low 32 bits of the extracted field
 */
template <typename UnsignedBits>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits, Int2Type<8> /*byte_len*/)
{
  // All-ones in the low `num_bits` positions.
  const unsigned long long low_bits = (1ull << num_bits) - 1;
  // Bring the field down to bit 0 before masking.
  const auto aligned = source >> bit_start;
  return aligned & low_bits;
}

#  if CUB_IS_INT128_ENABLED

/**
 * @brief Bitfield extract for 16-byte sources (selected by the `Int2Type<16>` tag).
 *
 * Implemented with plain shift-and-mask arithmetic on `__uint128_t`. The masked value is
 * narrowed to `unsigned int` on return.
 *
 * @param source    Value to extract bits from
 * @param bit_start Bit offset of the field
 * @param num_bits  Width of the field, in bits (the mask is built as `(1 << num_bits) - 1`)
 * @return The low 32 bits of the extracted field
 */
template <typename UnsignedBits>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits, Int2Type<16> /*byte_len*/)
{
  // All-ones in the low `num_bits` positions.
  const __uint128_t low_bits = (__uint128_t{1} << num_bits) - 1;
  // Bring the field down to bit 0 before masking.
  const auto aligned = source >> bit_start;
  return aligned & low_bits;
}
#  endif

#endif // DOXYGEN_SHOULD_SKIP_THIS

/**
 * @brief Bitfield extract: returns `num_bits` bits of `source` starting at bit `bit_start`.
 *
 * Dispatches on `sizeof(UnsignedBits)` to the matching tagged `BFE` overload.
 *
 * @param source    Value to extract bits from
 * @param bit_start Bit offset of the field
 * @param num_bits  Width of the field, in bits
 * @return The extracted field
 */
template <typename UnsignedBits>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits)
{
  using byte_len_tag = Int2Type<sizeof(UnsignedBits)>;
  return BFE(source, bit_start, num_bits, byte_len_tag());
}

/**
 * @brief Bitfield insert via the PTX `bfi.b32` instruction.
 *
 * @param[out] ret       Receives the instruction's result
 * @param[in]  x         Passed as the second source operand of `bfi.b32`
 * @param[in]  y         Passed as the first source operand of `bfi.b32`
 * @param[in]  bit_start Bit offset of the field
 * @param[in]  num_bits  Width of the field, in bits
 */
_CCCL_DEVICE _CCCL_FORCEINLINE void
BFI(unsigned int& ret, unsigned int x, unsigned int y, unsigned int bit_start, unsigned int num_bits)
{
  // Note the operand order: `y` precedes `x` in the instruction.
  // NOTE(review): per the PTX ISA this inserts the low bits of `y` into `x` — confirm.
  asm("bfi.b32 %0, %1, %2, %3, %4;" : "=r"(ret) : "r"(y), "r"(x), "r"(bit_start), "r"(num_bits));
}

/**
 * @brief Three-operand add, issued as the PTX `vadd.u32.u32.u32.add` instruction.
 *
 * @param x First addend
 * @param y Second addend
 * @param z Third addend
 * @return Result of the fused instruction, i.e. `x + y + z`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
{
  unsigned int sum;
  asm("vadd.u32.u32.u32.add %0, %1, %2, %3;"
      : "=r"(sum)
      : "r"(x), "r"(y), "r"(z));
  return sum;
}

/**
 * @brief Byte permute, issued as the PTX `prmt.b32` instruction.
 *
 * @param a     First 32-bit source register
 * @param b     Second 32-bit source register
 * @param index Selector operand passed to `prmt.b32`
 * @return The 32-bit result of the instruction, returned as `int`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int PRMT(unsigned int a, unsigned int b, unsigned int index)
{
  int permuted;
  asm("prmt.b32 %0, %1, %2, %3;"
      : "=r"(permuted)
      : "r"(a), "r"(b), "r"(index));
  return permuted;
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

/**
 * @brief Issues a PTX `bar.sync` on barrier resource 1.
 *
 * @param count Passed as the second operand of `bar.sync`
 *              (NOTE(review): per the PTX ISA this is the number of participating threads — confirm)
 */
_CCCL_DEVICE _CCCL_FORCEINLINE void BAR(int count)
{
  // The barrier index is hard-coded to 1; `volatile` keeps the compiler from eliding or moving it.
  asm volatile("bar.sync 1, %0;" : : "r"(count));
}

/**
 * @brief Thread-block barrier; forwards to `__syncthreads()`.
 */
_CCCL_DEVICE _CCCL_FORCEINLINE void CTA_SYNC()
{
  __syncthreads();
}

/**
 * @brief Thread-block barrier with predicate reduction; forwards to `__syncthreads_and()`.
 *
 * @param p Predicate contributed by the calling thread
 * @return The value returned by `__syncthreads_and(p)`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int CTA_SYNC_AND(int p)
{
  const int block_result = __syncthreads_and(p);
  return block_result;
}

/**
 * @brief Thread-block barrier with predicate reduction; forwards to `__syncthreads_or()`.
 *
 * @param p Predicate contributed by the calling thread
 * @return The value returned by `__syncthreads_or(p)`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int CTA_SYNC_OR(int p)
{
  const int block_result = __syncthreads_or(p);
  return block_result;
}

/**
 * @brief Warp-level barrier; forwards to `__syncwarp()`.
 *
 * @param member_mask Mask of participating lanes, passed through to `__syncwarp`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE void WARP_SYNC(unsigned int member_mask)
{
  __syncwarp(member_mask);
}

/**
 * @brief Warp vote "any"; forwards to `__any_sync()`.
 *
 * @param predicate   Predicate contributed by the calling lane
 * @param member_mask Mask of participating lanes
 * @return The value returned by `__any_sync(member_mask, predicate)`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ANY(int predicate, unsigned int member_mask)
{
  const int vote = __any_sync(member_mask, predicate);
  return vote;
}

/**
 * @brief Warp vote "all"; forwards to `__all_sync()`.
 *
 * @param predicate   Predicate contributed by the calling lane
 * @param member_mask Mask of participating lanes
 * @return The value returned by `__all_sync(member_mask, predicate)`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ALL(int predicate, unsigned int member_mask)
{
  const int vote = __all_sync(member_mask, predicate);
  return vote;
}

/**
 * @brief Warp ballot; forwards to `__ballot_sync()`.
 *
 * @param predicate   Predicate contributed by the calling lane
 * @param member_mask Mask of participating lanes
 * @return The ballot returned by `__ballot_sync(member_mask, predicate)`, converted to `int`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_BALLOT(int predicate, unsigned int member_mask)
{
  const auto ballot = __ballot_sync(member_mask, predicate);
  return ballot;
}

/**
 * @brief Warp-synchronous shuffle-up of one 32-bit word via PTX `shfl.sync.up.b32`.
 *
 * @param word        Word contributed by the calling lane
 * @param src_offset  Lane offset operand of the shuffle
 * @param flags       Packed control operand passed as-is to the instruction
 * @param member_mask Mask of participating lanes
 * @return The word produced by the shuffle for the calling lane
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
SHFL_UP_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
{
  unsigned int received;
  asm volatile("shfl.sync.up.b32 %0, %1, %2, %3, %4;"
               : "=r"(received)
               : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
  return received;
}

/**
 * @brief Warp-synchronous shuffle-down of one 32-bit word via PTX `shfl.sync.down.b32`.
 *
 * @param word        Word contributed by the calling lane
 * @param src_offset  Lane offset operand of the shuffle
 * @param flags       Packed control operand passed as-is to the instruction
 * @param member_mask Mask of participating lanes
 * @return The word produced by the shuffle for the calling lane
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
SHFL_DOWN_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
{
  unsigned int received;
  asm volatile("shfl.sync.down.b32 %0, %1, %2, %3, %4;"
               : "=r"(received)
               : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
  return received;
}

/**
 * @brief Warp-synchronous indexed shuffle of one 32-bit word via PTX `shfl.sync.idx.b32`.
 *
 * @param word        Word contributed by the calling lane
 * @param src_lane    Source lane operand of the shuffle
 * @param flags       Packed control operand passed as-is to the instruction
 * @param member_mask Mask of participating lanes
 * @return The word produced by the shuffle for the calling lane
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
SHFL_IDX_SYNC(unsigned int word, int src_lane, int flags, unsigned int member_mask)
{
  unsigned int received;
  asm volatile("shfl.sync.idx.b32 %0, %1, %2, %3, %4;"
               : "=r"(received)
               : "r"(word), "r"(src_lane), "r"(flags), "r"(member_mask));
  return received;
}

/**
 * @brief Warp-synchronous indexed shuffle of one 32-bit word; forwards to `__shfl_sync()`.
 *
 * Unlike the four-argument overload, no control word is supplied by the caller.
 *
 * @param word        Word contributed by the calling lane
 * @param src_lane    Lane to read from
 * @param member_mask Mask of participating lanes
 * @return The value returned by `__shfl_sync(member_mask, word, src_lane)`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int SHFL_IDX_SYNC(unsigned int word, int src_lane, unsigned int member_mask)
{
  const unsigned int received = __shfl_sync(member_mask, word, src_lane);
  return received;
}

/**
 * @brief Single-precision multiply with the `.rz` rounding modifier (PTX `mul.rz.f32`).
 *
 * @param a First factor
 * @param b Second factor
 * @return `a * b` as computed by `mul.rz.f32`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE float FMUL_RZ(float a, float b)
{
  float product;
  asm("mul.rz.f32 %0, %1, %2;"
      : "=f"(product)
      : "f"(a), "f"(b));
  return product;
}

/**
 * @brief Single-precision fused multiply-add with the `.rz` rounding modifier (PTX `fma.rz.f32`).
 *
 * @param a First factor
 * @param b Second factor
 * @param c Addend
 * @return `a * b + c` as computed by `fma.rz.f32`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE float FFMA_RZ(float a, float b, float c)
{
  float fused;
  asm("fma.rz.f32 %0, %1, %2, %3;"
      : "=f"(fused)
      : "f"(a), "f"(b), "f"(c));
  return fused;
}

#endif // DOXYGEN_SHOULD_SKIP_THIS

/**
 * @brief Issues the PTX `exit` instruction for the calling thread.
 */
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadExit()
{
  // `volatile` keeps the compiler from removing or reordering the instruction.
  asm volatile("exit;");
}

/**
 * @brief Issues the PTX `trap` instruction.
 */
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadTrap()
{
  // `volatile` keeps the compiler from removing or reordering the instruction.
  asm volatile("trap;");
}

/**
 * @brief Returns the row-major linear thread index of the calling thread within its block.
 *
 * Computes `threadIdx.x + threadIdx.y * block_dim_x + threadIdx.z * block_dim_x * block_dim_y`,
 * skipping the y (resp. z) term entirely when `block_dim_y` (resp. `block_dim_z`) equals 1.
 *
 * @param block_dim_x Block extent in x
 * @param block_dim_y Block extent in y
 * @param block_dim_z Block extent in z
 * @return The linearized thread index
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
{
  // Accumulate in unsigned arithmetic, matching the type of the threadIdx components.
  unsigned int linear_tid = threadIdx.x;

  if (block_dim_y != 1)
  {
    linear_tid += threadIdx.y * block_dim_x;
  }

  if (block_dim_z != 1)
  {
    linear_tid += threadIdx.z * block_dim_x * block_dim_y;
  }

  return linear_tid;
}

/**
 * @brief Returns the calling thread's lane index, read from the PTX `%laneid` special register.
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneId()
{
  unsigned int lane;
  asm("mov.u32 %0, %%laneid;"
      : "=r"(lane));
  return lane;
}

/**
 * @brief Returns the calling thread's warp identifier, read from the PTX `%warpid` special register.
 *
 * NOTE(review): `%warpid` is a hardware identifier; it presumably need not equal
 * `linear_tid / warp_size` within the block — confirm against the PTX ISA before using it as a rank.
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int WarpId()
{
  unsigned int warp;
  asm("mov.u32 %0, %%warpid;"
      : "=r"(warp));
  return warp;
}

/**
 * @brief Returns the lane mask covering one logical warp of `LOGICAL_WARP_THREADS` threads.
 *
 * The base mask has the low `LOGICAL_WARP_THREADS` bits set. When `LOGICAL_WARP_THREADS` is a
 * power of two smaller than the architectural warp, the mask is additionally shifted left by
 * `warp_id * LOGICAL_WARP_THREADS`. In every other case `warp_id` is ignored and the base mask
 * is returned unshifted.
 *
 * @tparam LOGICAL_WARP_THREADS Number of threads per logical warp
 * @tparam LEGACY_PTX_ARCH      Not used by this function
 * @param warp_id Index of the logical warp within the architectural warp
 * @return The member mask for that logical warp
 */
template <int LOGICAL_WARP_THREADS, int LEGACY_PTX_ARCH = 0>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE unsigned int WarpMask(unsigned int warp_id)
{
  // Only sub-warps with a power-of-two size are positioned by warp_id.
  constexpr bool shift_by_warp_id =
    PowerOfTwo<LOGICAL_WARP_THREADS>::VALUE && !(LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));

  // Low LOGICAL_WARP_THREADS bits set.
  unsigned int lanes = 0xFFFFFFFFu >> (CUB_WARP_THREADS(0) - LOGICAL_WARP_THREADS);

  _CCCL_IF_CONSTEXPR (shift_by_warp_id)
  {
    lanes <<= warp_id * LOGICAL_WARP_THREADS;
  }
  // Silences unused-parameter diagnostics when the branch above is discarded.
  (void) warp_id;

  return lanes;
}

/**
 * @brief Returns the PTX `%lanemask_lt` special register (mask of lanes below the calling lane).
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskLt()
{
  unsigned int lanes;
  asm("mov.u32 %0, %%lanemask_lt;"
      : "=r"(lanes));
  return lanes;
}

/**
 * @brief Returns the PTX `%lanemask_le` special register (mask of lanes at or below the calling lane).
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskLe()
{
  unsigned int lanes;
  asm("mov.u32 %0, %%lanemask_le;"
      : "=r"(lanes));
  return lanes;
}

/**
 * @brief Returns the PTX `%lanemask_gt` special register (mask of lanes above the calling lane).
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskGt()
{
  unsigned int lanes;
  asm("mov.u32 %0, %%lanemask_gt;"
      : "=r"(lanes));
  return lanes;
}

/**
 * @brief Returns the PTX `%lanemask_ge` special register (mask of lanes at or above the calling lane).
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskGe()
{
  unsigned int lanes;
  asm("mov.u32 %0, %%lanemask_ge;"
      : "=r"(lanes));
  return lanes;
}

/**
 * @brief Shuffle-up for an arbitrary type `T`, performed one shuffle word at a time.
 *
 * `input` is viewed as an array of `UnitWord<T>::ShuffleWord` elements; each element is widened
 * to `unsigned int`, exchanged with `SHFL_UP_SYNC`, and stored into the matching element of the
 * result. `T` must therefore be default-constructible.
 *
 * @tparam LOGICAL_WARP_THREADS Number of threads per logical warp
 * @tparam T                    Type of the item being shuffled
 * @param input        Item contributed by the calling lane
 * @param src_offset   Lane offset passed to the shuffle
 * @param first_thread OR-ed into the shuffle control word
 *                     (presumably the index of the first lane of the logical warp — confirm)
 * @param member_mask  Mask of participating lanes
 * @return The item assembled from the shuffled words
 */
template <int LOGICAL_WARP_THREADS, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T ShuffleUp(T input, int src_offset, int first_thread, unsigned int member_mask)
{
  using ShuffleWord = typename UnitWord<T>::ShuffleWord;

  // Upper field of the shuffle control word, derived from the logical warp size.
  constexpr int segment_bits = (32 - LOGICAL_WARP_THREADS) << 8;
  // Number of shuffle words needed to cover T (rounded up).
  constexpr int num_words = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

  const int control = first_thread | segment_bits;

  T output;
  ShuffleWord* src_words = reinterpret_cast<ShuffleWord*>(&input);
  ShuffleWord* dst_words = reinterpret_cast<ShuffleWord*>(&output);

#pragma unroll
  for (int w = 0; w < num_words; ++w)
  {
    const unsigned int received = SHFL_UP_SYNC((unsigned int) src_words[w], src_offset, control, member_mask);
    dst_words[w]                = received;
  }

  return output;
}

/**
 * @brief Shuffle-down for an arbitrary type `T`, performed one shuffle word at a time.
 *
 * `input` is viewed as an array of `UnitWord<T>::ShuffleWord` elements; each element is widened
 * to `unsigned int`, exchanged with `SHFL_DOWN_SYNC`, and stored into the matching element of the
 * result. `T` must therefore be default-constructible.
 *
 * @tparam LOGICAL_WARP_THREADS Number of threads per logical warp
 * @tparam T                    Type of the item being shuffled
 * @param input       Item contributed by the calling lane
 * @param src_offset  Lane offset passed to the shuffle
 * @param last_thread OR-ed into the shuffle control word
 *                    (presumably the index of the last lane of the logical warp — confirm)
 * @param member_mask Mask of participating lanes
 * @return The item assembled from the shuffled words
 */
template <int LOGICAL_WARP_THREADS, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T ShuffleDown(T input, int src_offset, int last_thread, unsigned int member_mask)
{
  using ShuffleWord = typename UnitWord<T>::ShuffleWord;

  // Upper field of the shuffle control word, derived from the logical warp size.
  constexpr int segment_bits = (32 - LOGICAL_WARP_THREADS) << 8;
  // Number of shuffle words needed to cover T (rounded up).
  constexpr int num_words = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

  const int control = last_thread | segment_bits;

  T output;
  ShuffleWord* src_words = reinterpret_cast<ShuffleWord*>(&input);
  ShuffleWord* dst_words = reinterpret_cast<ShuffleWord*>(&output);

#pragma unroll
  for (int w = 0; w < num_words; ++w)
  {
    const unsigned int received = SHFL_DOWN_SYNC((unsigned int) src_words[w], src_offset, control, member_mask);
    dst_words[w]                = received;
  }

  return output;
}

/**
 * @brief Indexed shuffle (broadcast) for an arbitrary type `T`, performed one shuffle word at a time.
 *
 * `input` is viewed as an array of `UnitWord<T>::ShuffleWord` elements; each element is widened
 * to `unsigned int`, exchanged with the four-argument `SHFL_IDX_SYNC`, and stored into the
 * matching element of the result. `T` must therefore be default-constructible.
 *
 * @tparam LOGICAL_WARP_THREADS Number of threads per logical warp
 * @tparam T                    Type of the item being shuffled
 * @param input       Item contributed by the calling lane
 * @param src_lane    Lane to read from
 * @param member_mask Mask of participating lanes
 * @return The item assembled from the shuffled words
 */
template <int LOGICAL_WARP_THREADS, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T ShuffleIndex(T input, int src_lane, unsigned int member_mask)
{
  using ShuffleWord = typename UnitWord<T>::ShuffleWord;

  // Shuffle control word: upper field from the logical warp size, lower field set to its last lane.
  constexpr int control = ((32 - LOGICAL_WARP_THREADS) << 8) | (LOGICAL_WARP_THREADS - 1);
  // Number of shuffle words needed to cover T (rounded up).
  constexpr int num_words = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

  T output;
  ShuffleWord* src_words = reinterpret_cast<ShuffleWord*>(&input);
  ShuffleWord* dst_words = reinterpret_cast<ShuffleWord*>(&output);

#pragma unroll
  for (int w = 0; w < num_words; ++w)
  {
    const unsigned int received = SHFL_IDX_SYNC((unsigned int) src_words[w], src_lane, control, member_mask);
    dst_words[w]                = received;
  }

  return output;
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
namespace detail
{

/**
 * @brief Matcher for warps in which only the first `WARP_ACTIVE_THREADS` lanes take part.
 *
 * Delegates to the full-warp specialization and then clears the bits of every lane at or above
 * `WARP_ACTIVE_THREADS` from the result.
 *
 * @tparam LABEL_BITS          Number of low-order label bits that are compared
 * @tparam WARP_ACTIVE_THREADS Number of participating lanes; expected to be smaller than the
 *                             full warp size (the full-warp case has its own specialization)
 */
template <int LABEL_BITS, int WARP_ACTIVE_THREADS>
struct warp_matcher_t
{
  /**
   * @param label Label contributed by the calling lane
   * @return Mask of lanes below `WARP_ACTIVE_THREADS` that the full-warp matcher reports as peers
   */
  static _CCCL_DEVICE unsigned int match_any(unsigned int label)
  {
    // Build the mask in unsigned arithmetic: left-shifting the negative value `~0` is undefined
    // behaviour before C++20, whereas `~0u << n` is well defined for 0 <= n < 32.
    const unsigned int active_lanes = ~(~0u << WARP_ACTIVE_THREADS);

    // Forward to the specialization by the same constant it is keyed on, so the two stay in sync.
    return warp_matcher_t<LABEL_BITS, CUB_PTX_WARP_THREADS>::match_any(label) & active_lanes;
  }
};

/**
 * @brief Full-warp matcher: finds the lanes whose low `LABEL_BITS` label bits equal the caller's.
 *
 * For every compared bit a warp ballot is taken of the lanes that have the bit set; lanes whose
 * own bit is clear invert that ballot, so each lane ends up with the set of lanes agreeing with
 * it on that bit. The per-bit sets are intersected.
 *
 * @tparam LABEL_BITS Number of low-order label bits that are compared. With `LABEL_BITS == 0`
 *                    nothing is compared and every lane is reported as a peer.
 */
template <int LABEL_BITS>
struct warp_matcher_t<LABEL_BITS, CUB_PTX_WARP_THREADS>
{
  // match.any.sync.b32 is slower when matching a few bits
  // using a ballot loop instead
  /**
   * @param label Label contributed by the calling lane
   * @return Mask of lanes whose low `LABEL_BITS` bits of `label` match the calling lane's
   */
  static _CCCL_DEVICE unsigned int match_any(unsigned int label)
  {
    // Start from "everyone matches" so the result is well defined even when LABEL_BITS == 0
    // (previously the accumulator was returned uninitialized in that case).
    unsigned int retval = 0xFFFFFFFFu;

// Extract masks of common threads for each bit
#  pragma unroll
    for (int BIT = 0; BIT < LABEL_BITS; ++BIT)
    {
      unsigned int mask;
      // Unsigned literal: avoids a signed shift into the sign bit when BIT == 31.
      unsigned int current_bit = 1u << BIT;
      asm("{\n"
          "    .reg .pred p;\n"
          "    and.b32 %0, %1, %2;"
          "    setp.ne.u32 p, %0, 0;\n"
          "    vote.ballot.sync.b32 %0, p, 0xffffffff;\n"
          "    @!p not.b32 %0, %0;\n"
          "}\n"
          : "=r"(mask)
          : "r"(label), "r"(current_bit));

      // Remove peers who differ
      retval &= mask;
    }

    return retval;
  }
};

} // namespace detail
#endif // DOXYGEN_SHOULD_SKIP_THIS

/**
 * @brief Computes a 32-bit mask of the lanes whose label matches the calling lane's.
 *
 * Thin front end over `detail::warp_matcher_t`, which compares the `LABEL_BITS` low-order bits
 * of `label`.
 *
 * @tparam LABEL_BITS          Number of low-order label bits that are compared
 * @tparam WARP_ACTIVE_THREADS Number of participating lanes (defaults to the full warp)
 * @param label Label contributed by the calling lane
 * @return Mask of matching lanes
 */
template <int LABEL_BITS, int WARP_ACTIVE_THREADS = CUB_PTX_WARP_THREADS>
inline _CCCL_DEVICE unsigned int MatchAny(unsigned int label)
{
  using matcher_t = detail::warp_matcher_t<LABEL_BITS, WARP_ACTIVE_THREADS>;
  return matcher_t::match_any(label);
}

CUB_NAMESPACE_END