cub/util_ptx.cuh
File members: cub/util_ptx.cuh
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cub/config.cuh>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header
#include <cub/util_debug.cuh>
#include <cub/util_type.cuh>
CUB_NAMESPACE_BEGIN
/******************************************************************************
* Inlined PTX intrinsics
******************************************************************************/
/**
 * @brief Shift-right then add, issued as one PTX video instruction.
 *
 * Emits `vshr.u32.u32.u32.clamp.add` with @p x, @p shift and @p addend as the
 * three source operands, i.e. the intent is `(x >> shift) + addend`.
 *
 * @param x       Value to shift
 * @param shift   Shift amount
 * @param addend  Value added to the shifted result
 * @return The destination register written by the instruction
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int SHR_ADD(unsigned int x, unsigned int shift, unsigned int addend)
{
  unsigned int shifted_sum;
  asm("vshr.u32.u32.u32.clamp.add %0, %1, %2, %3;" : "=r"(shifted_sum) : "r"(x), "r"(shift), "r"(addend));
  return shifted_sum;
}
/**
 * @brief Shift-left then add, issued as one PTX video instruction.
 *
 * Emits `vshl.u32.u32.u32.clamp.add` with @p x, @p shift and @p addend as the
 * three source operands, i.e. the intent is `(x << shift) + addend`.
 *
 * @param x       Value to shift
 * @param shift   Shift amount
 * @param addend  Value added to the shifted result
 * @return The destination register written by the instruction
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int SHL_ADD(unsigned int x, unsigned int shift, unsigned int addend)
{
  unsigned int shifted_sum;
  asm("vshl.u32.u32.u32.clamp.add %0, %1, %2, %3;" : "=r"(shifted_sum) : "r"(x), "r"(shift), "r"(addend));
  return shifted_sum;
}
#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
/**
 * Bitfield-extract, generic overload (selected for every byte length that has
 * no dedicated overload).
 *
 * The source is converted to `unsigned int` and the field is extracted with the
 * PTX `bfe.u32` instruction.
 *
 * @param source     Word to extract from (converted to `unsigned int`)
 * @param bit_start  Position of the first bit of the field
 * @param num_bits   Width of the field in bits
 * @return The extracted field
 */
template <typename UnsignedBits, int BYTE_LEN>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits, Int2Type<BYTE_LEN> /*byte_len*/)
{
  const unsigned int word = (unsigned int) source;
  unsigned int field;
  asm("bfe.u32 %0, %1, %2, %3;" : "=r"(field) : "r"(word), "r"(bit_start), "r"(num_bits));
  return field;
}
/**
 * Bitfield-extract for 8-byte source words, done with plain shift-and-mask
 * arithmetic instead of inline PTX.
 *
 * @param source     Word to extract from
 * @param bit_start  Position of the first bit of the field
 * @param num_bits   Width of the field in bits; must be below 64, since the
 *                   mask is built with `1ull << num_bits`
 * @return The masked field, converted to the `unsigned int` return type
 */
template <typename UnsignedBits>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits, Int2Type<8> /*byte_len*/)
{
  // Low `num_bits` bits set
  const unsigned long long field_mask = (1ull << num_bits) - 1;
  const auto aligned                  = source >> bit_start;
  return aligned & field_mask;
}
# if CUB_IS_INT128_ENABLED
/**
 * Bitfield-extract for 16-byte source words (only compiled when 128-bit
 * integers are enabled), done with plain shift-and-mask arithmetic.
 *
 * @param source     Word to extract from
 * @param bit_start  Position of the first bit of the field
 * @param num_bits   Width of the field in bits; must be below 128, since the
 *                   mask is built with `__uint128_t{1} << num_bits`
 * @return The masked field, converted to the `unsigned int` return type
 */
template <typename UnsignedBits>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits, Int2Type<16> /*byte_len*/)
{
  // Low `num_bits` bits set
  const __uint128_t field_mask = (__uint128_t{1} << num_bits) - 1;
  const auto aligned           = source >> bit_start;
  return aligned & field_mask;
}
# endif
#endif // _CCCL_DOXYGEN_INVOKED
/**
 * @brief Bitfield-extract. Extracts @p num_bits from @p source starting at
 *        bit-offset @p bit_start.
 *
 * Dispatches on `sizeof(UnsignedBits)` to the tagged `BFE` overload for that
 * byte length.
 *
 * @param source     Word to extract from
 * @param bit_start  Position of the first bit of the field
 * @param num_bits   Width of the field in bits
 * @return The extracted field as `unsigned int`
 */
template <typename UnsignedBits>
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits)
{
  using byte_len_tag = Int2Type<sizeof(UnsignedBits)>;
  return BFE(source, bit_start, num_bits, byte_len_tag());
}
/**
 * @brief Bitfield insert via the PTX `bfi.b32` instruction.
 *
 * @p y is passed as the field source and @p x as the base word, so the
 * @p num_bits least-significant bits of @p y are inserted into @p x at
 * bit-offset @p bit_start.
 *
 * @param[out] ret    Receives the combined word
 * @param x           Base word
 * @param y           Word supplying the bits to insert
 * @param bit_start   Bit offset of the insertion point
 * @param num_bits    Number of bits to insert
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE void
BFI(unsigned int& ret, unsigned int x, unsigned int y, unsigned int bit_start, unsigned int num_bits)
{
  unsigned int merged;
  asm("bfi.b32 %0, %1, %2, %3, %4;" : "=r"(merged) : "r"(y), "r"(x), "r"(bit_start), "r"(num_bits));
  ret = merged;
}
/**
 * @brief Three-operand add, issued as the PTX video instruction
 *        `vadd.u32.u32.u32.add`.
 *
 * @param x  First addend
 * @param y  Second addend
 * @param z  Third addend
 * @return The destination register written by the instruction (intended as
 *         `x + y + z`)
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
{
  unsigned int sum;
  asm("vadd.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(sum) : "r"(x), "r"(y), "r"(z));
  return sum;
}
/**
 * @brief Byte-permute via the PTX `prmt.b32` instruction.
 *
 * Selects bytes from the two source words @p a and @p b according to the
 * selector @p index and returns the assembled 32-bit word.
 *
 * @param a      First source word
 * @param b      Second source word
 * @param index  Byte selector operand
 * @return The permuted word (as `int`)
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE int PRMT(unsigned int a, unsigned int b, unsigned int index)
{
  int permuted;
  asm("prmt.b32 %0, %1, %2, %3;" : "=r"(permuted) : "r"(a), "r"(b), "r"(index));
  return permuted;
}
#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
/**
 * Synchronizes on barrier resource 1 by issuing the PTX instruction
 * `bar.sync 1, count`.
 *
 * @param count  Passed as the second `bar.sync` operand (the number of
 *               threads participating in the barrier, per the PTX ISA)
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE void BAR(int count)
{
// volatile: the barrier has no outputs and must not be removed or reordered
asm volatile("bar.sync 1, %0;" : : "r"(count));
}
/**
 * Thread-block barrier. Thin deprecated wrapper that forwards to
 * `__syncthreads()`.
 */
CCCL_DEPRECATED_BECAUSE("use __syncthreads() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE void CTA_SYNC()
{
__syncthreads();
}
/**
 * Thread-block barrier with predicate AND-reduction. Thin deprecated wrapper
 * that forwards to `__syncthreads_and()`.
 *
 * @param p  This thread's predicate
 * @return The value returned by `__syncthreads_and(p)`
 */
CCCL_DEPRECATED_BECAUSE("use __syncthreads_and() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int CTA_SYNC_AND(int p)
{
  const int block_and = __syncthreads_and(p);
  return block_and;
}
/**
 * Thread-block barrier with predicate OR-reduction. Thin deprecated wrapper
 * that forwards to `__syncthreads_or()`.
 *
 * @param p  This thread's predicate
 * @return The value returned by `__syncthreads_or(p)`
 */
CCCL_DEPRECATED_BECAUSE("use __syncthreads_or() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int CTA_SYNC_OR(int p)
{
  const int block_or = __syncthreads_or(p);
  return block_or;
}
/**
 * Warp barrier. Thin deprecated wrapper that forwards to `__syncwarp()`.
 *
 * @param member_mask  Mask of participating lanes, passed through unchanged
 */
CCCL_DEPRECATED_BECAUSE("use __syncwarp() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE void WARP_SYNC(unsigned int member_mask)
{
__syncwarp(member_mask);
}
/**
 * Warp vote "any". Thin deprecated wrapper around `__any_sync()`; note that the
 * argument order here (predicate first) is the reverse of the intrinsic's.
 *
 * @param predicate    This lane's predicate
 * @param member_mask  Mask of participating lanes
 * @return The value returned by `__any_sync(member_mask, predicate)`
 */
CCCL_DEPRECATED_BECAUSE("use __any_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ANY(int predicate, unsigned int member_mask)
{
  const int any_set = __any_sync(member_mask, predicate);
  return any_set;
}
/**
 * Warp vote "all". Thin deprecated wrapper around `__all_sync()`; note that the
 * argument order here (predicate first) is the reverse of the intrinsic's.
 *
 * @param predicate    This lane's predicate
 * @param member_mask  Mask of participating lanes
 * @return The value returned by `__all_sync(member_mask, predicate)`
 */
CCCL_DEPRECATED_BECAUSE("use __all_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ALL(int predicate, unsigned int member_mask)
{
  const int all_set = __all_sync(member_mask, predicate);
  return all_set;
}
/**
 * Warp ballot. Thin deprecated wrapper around `__ballot_sync()`; note that the
 * argument order here (predicate first) is the reverse of the intrinsic's.
 *
 * @param predicate    This lane's predicate
 * @param member_mask  Mask of participating lanes
 * @return The ballot returned by `__ballot_sync(member_mask, predicate)`,
 *         converted to the `int` return type
 */
CCCL_DEPRECATED_BECAUSE("use __ballot_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_BALLOT(int predicate, unsigned int member_mask)
{
  const unsigned int votes = __ballot_sync(member_mask, predicate);
  return votes;
}
/**
 * Warp-synchronous shuffle-up of one 32-bit word, issued as the PTX
 * instruction `shfl.sync.up.b32`.
 *
 * @param word         Word contributed by the calling lane
 * @param src_offset   Passed as the instruction's lane-offset operand
 * @param flags        Passed as the instruction's packed `c` operand (callers
 *                     in this file pass a lane bound OR'ed with a segment mask
 *                     shifted to bit 8)
 * @param member_mask  Mask of participating lanes
 * @return The word delivered to the calling lane
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
SHFL_UP_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
{
  unsigned int received;
  asm volatile("shfl.sync.up.b32 %0, %1, %2, %3, %4;"
               : "=r"(received)
               : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
  return received;
}
/**
 * Warp-synchronous shuffle-down of one 32-bit word, issued as the PTX
 * instruction `shfl.sync.down.b32`.
 *
 * @param word         Word contributed by the calling lane
 * @param src_offset   Passed as the instruction's lane-offset operand
 * @param flags        Passed as the instruction's packed `c` operand (callers
 *                     in this file pass a lane bound OR'ed with a segment mask
 *                     shifted to bit 8)
 * @param member_mask  Mask of participating lanes
 * @return The word delivered to the calling lane
 */
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
SHFL_DOWN_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
{
  unsigned int received;
  asm volatile("shfl.sync.down.b32 %0, %1, %2, %3, %4;"
               : "=r"(received)
               : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
  return received;
}
/**
 * Warp-synchronous indexed shuffle of one 32-bit word, issued as the PTX
 * instruction `shfl.sync.idx.b32`.
 *
 * @param word         Word contributed by the calling lane
 * @param src_lane     Passed as the instruction's source-lane operand
 * @param flags        Passed as the instruction's packed `c` operand
 * @param member_mask  Mask of participating lanes
 * @return The word delivered to the calling lane
 */
CCCL_DEPRECATED_BECAUSE("use __shfl_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int
SHFL_IDX_SYNC(unsigned int word, int src_lane, int flags, unsigned int member_mask)
{
  unsigned int received;
  asm volatile("shfl.sync.idx.b32 %0, %1, %2, %3, %4;"
               : "=r"(received)
               : "r"(word), "r"(src_lane), "r"(flags), "r"(member_mask));
  return received;
}
/**
 * Warp-synchronous indexed shuffle of one 32-bit word. Thin deprecated wrapper
 * that forwards to `__shfl_sync()` with the intrinsic's default width.
 *
 * @param word         Word contributed by the calling lane
 * @param src_lane     Lane to read from
 * @param member_mask  Mask of participating lanes
 * @return The value returned by `__shfl_sync(member_mask, word, src_lane)`
 */
CCCL_DEPRECATED_BECAUSE("use __shfl_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int SHFL_IDX_SYNC(unsigned int word, int src_lane, unsigned int member_mask)
{
  const unsigned int received = __shfl_sync(member_mask, word, src_lane);
  return received;
}
/**
 * Single-precision multiply using the PTX instruction `mul.rz.f32`
 * (the `.rz` modifier selects round-towards-zero).
 *
 * @param a  First factor
 * @param b  Second factor
 * @return The product written by the instruction
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE float FMUL_RZ(float a, float b)
{
  float product;
  asm("mul.rz.f32 %0, %1, %2;" : "=f"(product) : "f"(a), "f"(b));
  return product;
}
/**
 * Single-precision fused multiply-add using the PTX instruction `fma.rz.f32`
 * (the `.rz` modifier selects round-towards-zero).
 *
 * @param a  First factor
 * @param b  Second factor
 * @param c  Addend
 * @return The result written by the instruction (intended as `a * b + c`)
 */
CCCL_DEPRECATED_BECAUSE("will be removed in the next major release")
_CCCL_DEVICE _CCCL_FORCEINLINE float FFMA_RZ(float a, float b, float c)
{
  float fused;
  asm("fma.rz.f32 %0, %1, %2, %3;" : "=f"(fused) : "f"(a), "f"(b), "f"(c));
  return fused;
}
#endif // _CCCL_DOXYGEN_INVOKED
/**
 * @brief Terminates the calling thread by issuing the PTX `exit` instruction.
 */
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadExit()
{
// volatile: the statement has no outputs and must not be optimized away
asm volatile("exit;");
}
/**
 * @brief Aborts execution by issuing the PTX `trap` instruction.
 */
CCCL_DEPRECATED_BECAUSE("use cuda::std::terminate() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadTrap()
{
// volatile: the statement has no outputs and must not be optimized away
asm volatile("trap;");
}
/**
 * @brief Returns the row-major linear identifier of the calling thread within
 *        a block of the given dimensions.
 *
 * Computes `threadIdx.z * block_dim_x * block_dim_y + threadIdx.y * block_dim_x
 * + threadIdx.x`, leaving out the `z` term when `block_dim_z == 1` and the `y`
 * term when `block_dim_y == 1`.
 *
 * @param block_dim_x  Block extent in x
 * @param block_dim_y  Block extent in y
 * @param block_dim_z  Block extent in z
 * @return The linear thread identifier, converted to `int`
 */
_CCCL_DEVICE _CCCL_FORCEINLINE int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
{
  // Accumulate in unsigned arithmetic, exactly as the mixed int/threadIdx expression does
  unsigned int linear_tid = threadIdx.x;
  if (block_dim_y != 1)
  {
    linear_tid += threadIdx.y * block_dim_x;
  }
  if (block_dim_z != 1)
  {
    linear_tid += threadIdx.z * block_dim_x * block_dim_y;
  }
  return linear_tid;
}
/**
 * @brief Returns the calling thread's lane within its warp, read from the PTX
 *        special register `%laneid`.
 */
CCCL_DEPRECATED_BECAUSE("use cuda::ptx::get_sreg_laneid() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneId()
{
  unsigned int lane;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane));
  return lane;
}
/**
 * @brief Returns the calling thread's warp identifier, read from the PTX
 *        special register `%warpid`.
 */
CCCL_DEPRECATED_BECAUSE("use cuda::ptx::get_sreg_warpid() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int WarpId()
{
  unsigned int warp;
  asm("mov.u32 %0, %%warpid;" : "=r"(warp));
  return warp;
}
/**
 * @brief Returns the lane mask covering a logical warp.
 *
 * The base mask has the low `LOGICAL_WARP_THREADS` bits set. When
 * `LOGICAL_WARP_THREADS` is a power of two smaller than the architectural warp
 * size, the mask is additionally shifted left by
 * `warp_id * LOGICAL_WARP_THREADS`; otherwise @p warp_id is ignored.
 *
 * @tparam LOGICAL_WARP_THREADS  Number of threads per logical warp
 * @tparam LEGACY_PTX_ARCH       Unused
 *
 * @param warp_id  Index of the logical warp
 * @return The member mask for that logical warp
 */
template <int LOGICAL_WARP_THREADS, int LEGACY_PTX_ARCH = 0>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE unsigned int WarpMask(unsigned int warp_id)
{
  // Only power-of-two sub-warps get a per-warp shifted mask
  constexpr bool is_shifted_sub_warp =
    PowerOfTwo<LOGICAL_WARP_THREADS>::VALUE && !(LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));

  const unsigned int low_lanes = 0xFFFFFFFFu >> (CUB_WARP_THREADS(0) - LOGICAL_WARP_THREADS);
  unsigned int warp_mask       = low_lanes;

  _CCCL_IF_CONSTEXPR (is_shifted_sub_warp)
  {
    warp_mask <<= warp_id * LOGICAL_WARP_THREADS;
  }

  // Keeps the parameter "used" when the branch above is compiled out
  (void) warp_id;
  return warp_mask;
}
/**
 * @brief Returns the PTX special register `%lanemask_lt` (mask of lanes with an
 *        id less than the calling thread's).
 */
CCCL_DEPRECATED_BECAUSE("use cuda::ptx::get_sreg_lanemask_lt() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskLt()
{
  unsigned int lane_mask;
  asm("mov.u32 %0, %%lanemask_lt;" : "=r"(lane_mask));
  return lane_mask;
}
/**
 * @brief Returns the PTX special register `%lanemask_le` (mask of lanes with an
 *        id less than or equal to the calling thread's).
 */
CCCL_DEPRECATED_BECAUSE("use cuda::ptx::get_sreg_lanemask_le() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskLe()
{
  unsigned int lane_mask;
  asm("mov.u32 %0, %%lanemask_le;" : "=r"(lane_mask));
  return lane_mask;
}
/**
 * @brief Returns the PTX special register `%lanemask_gt` (mask of lanes with an
 *        id greater than the calling thread's).
 */
CCCL_DEPRECATED_BECAUSE("use cuda::ptx::get_sreg_lanemask_gt() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskGt()
{
  unsigned int lane_mask;
  asm("mov.u32 %0, %%lanemask_gt;" : "=r"(lane_mask));
  return lane_mask;
}
/**
 * @brief Returns the PTX special register `%lanemask_ge` (mask of lanes with an
 *        id greater than or equal to the calling thread's).
 */
CCCL_DEPRECATED_BECAUSE("use cuda::ptx::get_sreg_lanemask_ge() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE unsigned int LaneMaskGe()
{
  unsigned int lane_mask;
  asm("mov.u32 %0, %%lanemask_ge;" : "=r"(lane_mask));
  return lane_mask;
}
/**
 * @brief Shuffle-up for any data type.
 *
 * The value is viewed as an array of `UnitWord<T>::ShuffleWord` units and each
 * unit is moved with ::SHFL_UP_SYNC, one shuffle per unit.
 *
 * @tparam LOGICAL_WARP_THREADS  Number of threads per logical warp (at most 32)
 * @tparam T                     Type of the shuffled value
 *
 * @param input         Value contributed by the calling lane
 * @param src_offset    Lane offset handed to every shuffle
 * @param first_thread  Lane bound; OR'ed with the segment mask to form the
 *                      shuffle's packed `c` operand
 * @param member_mask   Mask of participating lanes
 * @return The value reassembled from the shuffled units
 */
template <int LOGICAL_WARP_THREADS, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T ShuffleUp(T input, int src_offset, int first_thread, unsigned int member_mask)
{
  using ShuffleWord = typename UnitWord<T>::ShuffleWord;

  // Segment-mask portion of the packed shuffle control operand
  constexpr int segment_bits = (32 - LOGICAL_WARP_THREADS) << 8;
  // Number of shuffle units needed to cover T (rounded up, so always >= 1)
  constexpr int num_words = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

  T output;
  ShuffleWord* const dst_words = reinterpret_cast<ShuffleWord*>(&output);
  ShuffleWord* const src_words = reinterpret_cast<ShuffleWord*>(&input);

#pragma unroll
  for (int w = 0; w < num_words; ++w)
  {
    const unsigned int moved =
      SHFL_UP_SYNC((unsigned int) src_words[w], src_offset, first_thread | segment_bits, member_mask);
    dst_words[w] = moved;
  }

  return output;
}
/**
 * @brief Shuffle-down for any data type.
 *
 * The value is viewed as an array of `UnitWord<T>::ShuffleWord` units and each
 * unit is moved with ::SHFL_DOWN_SYNC, one shuffle per unit.
 *
 * @tparam LOGICAL_WARP_THREADS  Number of threads per logical warp (at most 32)
 * @tparam T                     Type of the shuffled value
 *
 * @param input        Value contributed by the calling lane
 * @param src_offset   Lane offset handed to every shuffle
 * @param last_thread  Lane bound; OR'ed with the segment mask to form the
 *                     shuffle's packed `c` operand
 * @param member_mask  Mask of participating lanes
 * @return The value reassembled from the shuffled units
 */
template <int LOGICAL_WARP_THREADS, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T ShuffleDown(T input, int src_offset, int last_thread, unsigned int member_mask)
{
  using ShuffleWord = typename UnitWord<T>::ShuffleWord;

  // Segment-mask portion of the packed shuffle control operand
  constexpr int segment_bits = (32 - LOGICAL_WARP_THREADS) << 8;
  // Number of shuffle units needed to cover T (rounded up, so always >= 1)
  constexpr int num_words = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

  T output;
  ShuffleWord* const dst_words = reinterpret_cast<ShuffleWord*>(&output);
  ShuffleWord* const src_words = reinterpret_cast<ShuffleWord*>(&input);

#pragma unroll
  for (int w = 0; w < num_words; ++w)
  {
    const unsigned int moved =
      SHFL_DOWN_SYNC((unsigned int) src_words[w], src_offset, last_thread | segment_bits, member_mask);
    dst_words[w] = moved;
  }

  return output;
}
/**
 * @brief Indexed shuffle (broadcast) for any data type.
 *
 * The value is viewed as an array of `UnitWord<T>::ShuffleWord` units and each
 * unit is fetched from @p src_lane with `__shfl_sync`, using
 * `LOGICAL_WARP_THREADS` as the shuffle width.
 *
 * @tparam LOGICAL_WARP_THREADS  Number of threads per logical warp
 * @tparam T                     Type of the shuffled value
 *
 * @param input        Value contributed by the calling lane
 * @param src_lane     Lane to read from
 * @param member_mask  Mask of participating lanes
 * @return The value reassembled from the shuffled units
 */
template <int LOGICAL_WARP_THREADS, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T ShuffleIndex(T input, int src_lane, unsigned int member_mask)
{
  using ShuffleWord = typename UnitWord<T>::ShuffleWord;

  // Number of shuffle units needed to cover T (rounded up, so always >= 1)
  constexpr int num_words = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

  T output;
  ShuffleWord* const dst_words = reinterpret_cast<ShuffleWord*>(&output);
  ShuffleWord* const src_words = reinterpret_cast<ShuffleWord*>(&input);

#pragma unroll
  for (int w = 0; w < num_words; ++w)
  {
    const unsigned int moved = __shfl_sync(member_mask, (unsigned int) src_words[w], src_lane, LOGICAL_WARP_THREADS);
    dst_words[w]             = moved;
  }

  return output;
}
#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
namespace detail
{
/**
 * Label matcher for warps in which only the first WARP_ACTIVE_THREADS lanes
 * take part.
 *
 * Delegates to the full-warp specialization and then keeps only the bits of
 * the low WARP_ACTIVE_THREADS lanes.
 *
 * @tparam LABEL_BITS           Number of low label bits that are compared
 * @tparam WARP_ACTIVE_THREADS  Number of participating lanes; expected to be
 *                              smaller than CUB_PTX_WARP_THREADS (the equal
 *                              case is handled by the specialization)
 */
template <int LABEL_BITS, int WARP_ACTIVE_THREADS>
struct warp_matcher_t
{
  /**
   * @param label  Label of the calling thread
   * @return Mask of lanes below WARP_ACTIVE_THREADS that the full-warp
   *         matcher reports as peers of the calling thread
   */
  static _CCCL_DEVICE unsigned int match_any(unsigned int label)
  {
    // Build the mask from an unsigned all-ones value: left-shifting the signed
    // value ~0 (i.e. -1) is undefined behavior before C++20.
    constexpr unsigned int active_lanes = ~(~0u << WARP_ACTIVE_THREADS);

    // Name the full-warp width the same way the specialization does, so the
    // delegation always lands on it.
    return warp_matcher_t<LABEL_BITS, CUB_PTX_WARP_THREADS>::match_any(label) & active_lanes;
  }
};
/**
 * Full-warp label matcher: for the calling thread, computes the mask of lanes
 * whose label agrees with its own in the low LABEL_BITS bits.
 *
 * @tparam LABEL_BITS  Number of low label bits that are compared
 */
template <int LABEL_BITS>
struct warp_matcher_t<LABEL_BITS, CUB_PTX_WARP_THREADS>
{
// match.any.sync.b32 is slower when only a few bits have to be matched,
// so a loop of one ballot per label bit is used instead
//
// Returns the mask of peer lanes for `label`.
// NOTE(review): if LABEL_BITS were 0 the loop body would never run and
// `retval` would be returned uninitialized - confirm callers always pass
// LABEL_BITS >= 1.
static _CCCL_DEVICE unsigned int match_any(unsigned int label)
{
unsigned int retval;
// Extract masks of common threads for each bit
# pragma unroll
for (int BIT = 0; BIT < LABEL_BITS; ++BIT)
{
unsigned int mask;
unsigned int current_bit = 1 << BIT;
// p    := (label & current_bit) != 0
// mask := ballot of p over the full warp (member mask 0xffffffff)
// If this thread's bit is clear (!p), the ballot is inverted so that `mask`
// always holds the lanes whose bit equals this thread's bit.
asm("{\n"
" .reg .pred p;\n"
" and.b32 %0, %1, %2;"
" setp.ne.u32 p, %0, 0;\n"
" vote.ballot.sync.b32 %0, p, 0xffffffff;\n"
" @!p not.b32 %0, %0;\n"
"}\n"
: "=r"(mask)
: "r"(label), "r"(current_bit));
// Remove peers who differ: the first bit seeds the result, later bits narrow it
retval = (BIT == 0) ? mask : retval & mask;
}
return retval;
}
};
/**
 * Logical left shift issued directly as the PTX instruction `shl.b32`.
 *
 * @param val       Value to shift
 * @param num_bits  Shift amount
 * @return The shifted value written by the instruction
 */
_CCCL_DEVICE _CCCL_FORCEINLINE uint32_t LogicShiftLeft(uint32_t val, uint32_t num_bits)
{
  uint32_t shifted = 0;
  asm("shl.b32 %0, %1, %2;" : "=r"(shifted) : "r"(val), "r"(num_bits));
  return shifted;
}
/**
 * Logical right shift issued directly as the PTX instruction `shr.b32`.
 *
 * @param val       Value to shift
 * @param num_bits  Shift amount
 * @return The shifted value written by the instruction
 */
_CCCL_DEVICE _CCCL_FORCEINLINE uint32_t LogicShiftRight(uint32_t val, uint32_t num_bits)
{
  uint32_t shifted = 0;
  asm("shr.b32 %0, %1, %2;" : "=r"(shifted) : "r"(val), "r"(num_bits));
  return shifted;
}
} // namespace detail
#endif // _CCCL_DOXYGEN_INVOKED
/**
 * @brief Computes the mask of warp lanes whose label matches the calling
 *        thread's label.
 *
 * Forwards to `detail::warp_matcher_t<LABEL_BITS, WARP_ACTIVE_THREADS>`.
 *
 * @tparam LABEL_BITS           Number of low label bits that are compared
 * @tparam WARP_ACTIVE_THREADS  Number of participating lanes (defaults to the
 *                              full warp)
 *
 * @param label  Label of the calling thread
 * @return Mask of peer lanes as reported by the matcher
 */
template <int LABEL_BITS, int WARP_ACTIVE_THREADS = CUB_PTX_WARP_THREADS>
inline _CCCL_DEVICE unsigned int MatchAny(unsigned int label)
{
  using matcher_t = detail::warp_matcher_t<LABEL_BITS, WARP_ACTIVE_THREADS>;
  return matcher_t::match_any(label);
}
CUB_NAMESPACE_END