// cub/thread/thread_store.cuh
//
// File members: cub/thread/thread_store.cuh

/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>

CUB_NAMESPACE_BEGIN

//-----------------------------------------------------------------------------
// Tags and constants
//-----------------------------------------------------------------------------

/**
 * @brief Enumeration of cache modifiers for memory store operations.
 *
 * The enumerator is passed as the first template argument of ThreadStore and
 * selects which store implementation is used. The cache-operator descriptions
 * below follow the PTX ISA definitions of the corresponding `st` qualifiers.
 */
enum CacheStoreModifier
{
  STORE_DEFAULT  = 0, ///< Default (no modifier): plain assignment through the iterator
  STORE_WB       = 1, ///< Cache write-back all coherent levels (PTX `st.wb`)
  STORE_CG       = 2, ///< Cache at global level (PTX `st.cg`)
  STORE_CS       = 3, ///< Cache streaming, likely to be accessed once (PTX `st.cs`)
  STORE_WT       = 4, ///< Cache write-through to system memory (PTX `st.wt`)
  STORE_VOLATILE = 5, ///< Store through a volatile-qualified pointer
};

/**
 * @brief Thread utility for writing a value to memory using a cub::CacheStoreModifier.
 *
 * This is the public entry point. Its definition (further down in this file)
 * dispatches on @p MODIFIER and on whether @p OutputIteratorT is a raw pointer.
 *
 * @tparam MODIFIER
 *   <b>[required]</b> CacheStoreModifier enumerator selecting the store variant
 *
 * @tparam OutputIteratorT
 *   <b>[inferred]</b> Type of the destination iterator or pointer
 *
 * @tparam T
 *   <b>[inferred]</b> Type of the value to store
 *
 * @param[in] itr
 *   Destination iterator or pointer
 *
 * @param[in] val
 *   Value to store
 */
template <CacheStoreModifier MODIFIER, typename OutputIteratorT, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore(OutputIteratorT itr, T val);

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document

/**
 * @brief Compile-time recursion that stores array elements COUNT, COUNT + 1, ..., MAX - 1.
 *
 * Each instantiation handles element COUNT and then delegates to
 * IterateThreadStore<COUNT + 1, MAX>; the <MAX, MAX> specialization ends the recursion.
 */
template <int COUNT, int MAX>
struct IterateThreadStore
{
  /// Stores vals[COUNT] to (ptr + COUNT) through ThreadStore<MODIFIER>, then processes the remaining elements.
  template <CacheStoreModifier MODIFIER, typename T>
  static _CCCL_DEVICE _CCCL_FORCEINLINE void Store(T* ptr, T* vals)
  {
    using Remaining = IterateThreadStore<COUNT + 1, MAX>;

    // Element COUNT is written first so the stores are issued in ascending index order.
    ThreadStore<MODIFIER>(ptr + COUNT, vals[COUNT]);
    Remaining::template Store<MODIFIER>(ptr, vals);
  }

  /// Assigns vals[COUNT] to ptr[COUNT] with a plain indexed assignment, then processes the remaining elements.
  template <typename OutputIteratorT, typename T>
  static _CCCL_DEVICE _CCCL_FORCEINLINE void Dereference(OutputIteratorT ptr, T* vals)
  {
    using Remaining = IterateThreadStore<COUNT + 1, MAX>;

    ptr[COUNT] = vals[COUNT];
    Remaining::Dereference(ptr, vals);
  }
};

/**
 * @brief Recursion terminator for IterateThreadStore, selected once COUNT has reached MAX.
 *
 * Both members are intentionally empty: there are no elements left to store.
 */
template <int MAX>
struct IterateThreadStore<MAX, MAX>
{
  /// End of the ThreadStore<MODIFIER> recursion; does nothing.
  template <CacheStoreModifier MODIFIER, typename T>
  static _CCCL_DEVICE _CCCL_FORCEINLINE void Store(T*, T*)
  {}

  /// End of the indexed-assignment recursion; does nothing.
  template <typename OutputIteratorT, typename T>
  static _CCCL_DEVICE _CCCL_FORCEINLINE void Dereference(OutputIteratorT, T*)
  {}
};

/**
 * Defines the explicit ThreadStore specializations for the vector types uint4 and ulonglong2.
 *
 * cub_modifier: CacheStoreModifier enumerator the specializations are keyed on.
 * ptx_modifier: token stringized directly after "st." to form the PTX store instruction.
 *
 * - uint4 is written with one `st.<ptx_modifier>.v4.u32` (components x, y, z, w as "r" operands).
 * - ulonglong2 is written with one `st.<ptx_modifier>.v2.u64` (components x, y as "l" operands).
 *
 * The destination address is passed with the "l" constraint in both cases.
 */
#  define _CUB_STORE_16(cub_modifier, ptx_modifier)                                                      \
    template <>                                                                                          \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, uint4*, uint4>(uint4 * ptr, uint4 val) \
    {                                                                                                    \
      asm volatile("st." #ptx_modifier ".v4.u32 [%0], {%1, %2, %3, %4};"                                 \
                   :                                                                                     \
                   : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));                          \
    }                                                                                                    \
    template <>                                                                                          \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, ulonglong2*, ulonglong2>(              \
      ulonglong2 * ptr, ulonglong2 val)                                                                  \
    {                                                                                                    \
      asm volatile("st." #ptx_modifier ".v2.u64 [%0], {%1, %2};" : : "l"(ptr), "l"(val.x), "l"(val.y));  \
    }

/**
 * Defines the explicit ThreadStore specializations for ushort4, uint2 and unsigned long long.
 *
 * cub_modifier: CacheStoreModifier enumerator the specializations are keyed on.
 * ptx_modifier: token stringized directly after "st." to form the PTX store instruction.
 *
 * - ushort4 is written with one `st.<ptx_modifier>.v4.u16` (components as "h" operands).
 * - uint2 is written with one `st.<ptx_modifier>.v2.u32` (components as "r" operands).
 * - unsigned long long is written with one `st.<ptx_modifier>.u64` (value as an "l" operand).
 *
 * The destination address is passed with the "l" constraint in all three cases.
 */
#  define _CUB_STORE_8(cub_modifier, ptx_modifier)                                                               \
    template <>                                                                                                  \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, ushort4*, ushort4>(ushort4 * ptr, ushort4 val) \
    {                                                                                                            \
      asm volatile("st." #ptx_modifier ".v4.u16 [%0], {%1, %2, %3, %4};"                                         \
                   :                                                                                             \
                   : "l"(ptr), "h"(val.x), "h"(val.y), "h"(val.z), "h"(val.w));                                  \
    }                                                                                                            \
    template <>                                                                                                  \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, uint2*, uint2>(uint2 * ptr, uint2 val)         \
    {                                                                                                            \
      asm volatile("st." #ptx_modifier ".v2.u32 [%0], {%1, %2};" : : "l"(ptr), "r"(val.x), "r"(val.y));          \
    }                                                                                                            \
    template <>                                                                                                  \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, unsigned long long*, unsigned long long>(      \
      unsigned long long* ptr, unsigned long long val)                                                           \
    {                                                                                                            \
      asm volatile("st." #ptx_modifier ".u64 [%0], %1;" : : "l"(ptr), "l"(val));                                 \
    }

/**
 * Defines the explicit ThreadStore specialization for unsigned int.
 *
 * cub_modifier: CacheStoreModifier enumerator the specialization is keyed on.
 * ptx_modifier: token stringized directly after "st." to form the PTX store instruction.
 *
 * The value is written with one `st.<ptx_modifier>.u32`; the address is an "l" operand
 * and the value an "r" operand.
 */
#  define _CUB_STORE_4(cub_modifier, ptx_modifier)                                              \
    template <>                                                                                 \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, unsigned int*, unsigned int>( \
      unsigned int* ptr, unsigned int val)                                                      \
    {                                                                                           \
      asm volatile("st." #ptx_modifier ".u32 [%0], %1;" : : "l"(ptr), "r"(val));                \
    }

/**
 * Defines the explicit ThreadStore specialization for unsigned short.
 *
 * cub_modifier: CacheStoreModifier enumerator the specialization is keyed on.
 * ptx_modifier: token stringized directly after "st." to form the PTX store instruction.
 *
 * The value is written with one `st.<ptx_modifier>.u16`; the address is an "l" operand
 * and the value an "h" operand.
 */
#  define _CUB_STORE_2(cub_modifier, ptx_modifier)                                                  \
    template <>                                                                                     \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, unsigned short*, unsigned short>( \
      unsigned short* ptr, unsigned short val)                                                      \
    {                                                                                               \
      asm volatile("st." #ptx_modifier ".u16 [%0], %1;" : : "l"(ptr), "h"(val));                    \
    }

/**
 * Defines the explicit ThreadStore specialization for unsigned char.
 *
 * cub_modifier: CacheStoreModifier enumerator the specialization is keyed on.
 * ptx_modifier: token stringized directly after "st." to form the PTX store instruction.
 *
 * The byte is widened to unsigned short and passed as an "h" (16-bit) operand, because
 * inline PTX offers no constraint letter for 8-bit registers. Inside the asm block it is
 * narrowed again with `cvt.u8.u16` into the locally declared `.u8` register `datum`, which
 * is then written with `st.<ptx_modifier>.u8`. The surrounding braces scope `datum` to
 * this one asm statement.
 */
#  define _CUB_STORE_1(cub_modifier, ptx_modifier)                                                \
    template <>                                                                                   \
    _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, unsigned char*, unsigned char>( \
      unsigned char* ptr, unsigned char val)                                                      \
    {                                                                                             \
      asm volatile(                                                                               \
        "{"                                                                                       \
        "   .reg .u8 datum;"                                                                      \
        "   cvt.u8.u16 datum, %1;"                                                                \
        "   st." #ptx_modifier ".u8 [%0], datum;"                                                 \
        "}"                                                                                       \
        :                                                                                         \
        : "l"(ptr), "h"((unsigned short) val));                                                   \
    }

/**
 * Expands every per-type store macro (_CUB_STORE_16, _8, _4, _2, _1) for a single
 * (cub_modifier, ptx_modifier) pair, producing the full set of explicit ThreadStore
 * specializations for that modifier.
 */
#  define _CUB_STORE_ALL(cub_modifier, ptx_modifier) \
    _CUB_STORE_16(cub_modifier, ptx_modifier)        \
    _CUB_STORE_8(cub_modifier, ptx_modifier)         \
    _CUB_STORE_4(cub_modifier, ptx_modifier)         \
    _CUB_STORE_2(cub_modifier, ptx_modifier)         \
    _CUB_STORE_1(cub_modifier, ptx_modifier)

// Generate the PTX-backed specializations for each modifier that maps onto a PTX store
// qualifier: STORE_WB -> st.wb, STORE_CG -> st.cg, STORE_CS -> st.cs, STORE_WT -> st.wt.
// STORE_DEFAULT and STORE_VOLATILE are not generated here; they are implemented by the
// ordinary function overloads further down in this file.
_CUB_STORE_ALL(STORE_WB, wb)
_CUB_STORE_ALL(STORE_CG, cg)
_CUB_STORE_ALL(STORE_CS, cs)
_CUB_STORE_ALL(STORE_WT, wt)

// Macro cleanup: the _CUB_STORE_* helpers are only needed to generate the specializations
// above, so undefine them to avoid leaking them into code that includes this header.
#  undef _CUB_STORE_ALL
#  undef _CUB_STORE_1
#  undef _CUB_STORE_2
#  undef _CUB_STORE_4
#  undef _CUB_STORE_8
#  undef _CUB_STORE_16

/**
 * ThreadStore implementation for the STORE_DEFAULT modifier on non-pointer iterator types.
 *
 * Selected by the two trailing tag arguments (modifier == STORE_DEFAULT, is_pointer == false);
 * the tags carry no data. The value is written by assigning through the dereferenced iterator.
 *
 * @param itr  Destination iterator
 * @param val  Value to store
 */
template <typename OutputIteratorT, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(OutputIteratorT itr, T val, Int2Type<STORE_DEFAULT> /*modifier*/, Int2Type<false> /*is_pointer*/)
{
  *itr = val;
}

/**
 * ThreadStore implementation for the STORE_DEFAULT modifier on raw pointers.
 *
 * Selected by the two trailing tag arguments (modifier == STORE_DEFAULT, is_pointer == true);
 * the tags carry no data. The value is written with an ordinary assignment to the pointee.
 *
 * @param ptr  Destination pointer
 * @param val  Value to store
 */
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, Int2Type<STORE_DEFAULT> /*modifier*/, Int2Type<true> /*is_pointer*/)
{
  ptr[0] = val;
}

/**
 * Volatile store helper for types tagged as primitive (is_primitive == true).
 *
 * The whole value is written with a single assignment through a volatile-qualified
 * view of the destination.
 *
 * @param ptr  Destination pointer
 * @param val  Value to store
 */
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStoreVolatilePtr(T* ptr, T val, Int2Type<true> /*is_primitive*/)
{
  volatile T* const volatile_dst = reinterpret_cast<volatile T*>(ptr);
  *volatile_dst                  = val;
}

/**
 * Volatile store helper for types tagged as non-primitive (is_primitive == false).
 *
 * The value is first copied, ShuffleWord by ShuffleWord, into a local staging array of
 * VolatileWord units. The staging array is then written to the destination one
 * VolatileWord at a time through a volatile-qualified pointer.
 *
 * @param ptr  Destination pointer
 * @param val  Value to store
 */
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStoreVolatilePtr(T* ptr, T val, Int2Type<false> /*is_primitive*/)
{
  using Units        = UnitWord<T>;
  using VolatileWord = typename Units::VolatileWord;
  using ShuffleWord  = typename Units::ShuffleWord;

  // Number of words of each kind that make up one T
  constexpr int NUM_VOLATILE_WORDS = sizeof(T) / sizeof(VolatileWord);
  constexpr int NUM_SHUFFLE_WORDS  = sizeof(T) / sizeof(ShuffleWord);

  VolatileWord staging[NUM_VOLATILE_WORDS];

  // Fill the staging buffer from the by-value argument, viewing both as ShuffleWord arrays
  ShuffleWord* const src = reinterpret_cast<ShuffleWord*>(&val);
  ShuffleWord* const dst = reinterpret_cast<ShuffleWord*>(staging);

#  pragma unroll
  for (int word = 0; word < NUM_SHUFFLE_WORDS; ++word)
  {
    dst[word] = src[word];
  }

  // Emit one volatile assignment per VolatileWord
  volatile VolatileWord* const volatile_dst = reinterpret_cast<volatile VolatileWord*>(ptr);
  IterateThreadStore<0, NUM_VOLATILE_WORDS>::Dereference(volatile_dst, staging);
}

/**
 * ThreadStore implementation for the STORE_VOLATILE modifier on raw pointers.
 *
 * Forwards to ThreadStoreVolatilePtr, choosing the primitive or non-primitive variant
 * according to Traits<T>::PRIMITIVE.
 *
 * @param ptr  Destination pointer
 * @param val  Value to store
 */
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, Int2Type<STORE_VOLATILE> /*modifier*/, Int2Type<true> /*is_pointer*/)
{
  using IsPrimitiveTag = Int2Type<Traits<T>::PRIMITIVE>;

  ThreadStoreVolatilePtr(ptr, val, IsPrimitiveTag());
}

/**
 * ThreadStore implementation for any other modifier on raw pointers.
 *
 * The value is first copied, ShuffleWord by ShuffleWord, into a local staging array of
 * DeviceWord units. Each DeviceWord is then written with
 * ThreadStore<CacheStoreModifier(MODIFIER)> on a DeviceWord pointer.
 * NOTE(review): these per-word calls presumably resolve to the explicit PTX
 * specializations defined earlier in this file; that depends on which types UnitWord
 * picks for DeviceWord, which is not visible here.
 *
 * @param ptr  Destination pointer
 * @param val  Value to store
 */
template <typename T, int MODIFIER>
_CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, Int2Type<MODIFIER> /*modifier*/, Int2Type<true> /*is_pointer*/)
{
  using Units       = UnitWord<T>;
  using DeviceWord  = typename Units::DeviceWord;
  using ShuffleWord = typename Units::ShuffleWord;

  // Number of words of each kind that make up one T
  constexpr int NUM_DEVICE_WORDS  = sizeof(T) / sizeof(DeviceWord);
  constexpr int NUM_SHUFFLE_WORDS = sizeof(T) / sizeof(ShuffleWord);

  DeviceWord staging[NUM_DEVICE_WORDS];

  // Fill the staging buffer from the by-value argument, viewing both as ShuffleWord arrays
  ShuffleWord* const src = reinterpret_cast<ShuffleWord*>(&val);
  ShuffleWord* const dst = reinterpret_cast<ShuffleWord*>(staging);

#  pragma unroll
  for (int word = 0; word < NUM_SHUFFLE_WORDS; ++word)
  {
    dst[word] = src[word];
  }

  // Emit one cache-modified store per DeviceWord
  using WordStores = IterateThreadStore<0, NUM_DEVICE_WORDS>;
  WordStores::template Store<CacheStoreModifier(MODIFIER)>(reinterpret_cast<DeviceWord*>(ptr), staging);
}

/**
 * Definition of the public ThreadStore entry point.
 *
 * Builds two tag objects, one encoding MODIFIER and one encoding whether OutputIteratorT
 * is a raw pointer, and lets overload resolution pick the matching four-argument
 * ThreadStore implementation.
 *
 * @param itr  Destination iterator or pointer
 * @param val  Value to store
 */
template <CacheStoreModifier MODIFIER, typename OutputIteratorT, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore(OutputIteratorT itr, T val)
{
  using ModifierTag  = Int2Type<MODIFIER>;
  using IsPointerTag = Int2Type<std::is_pointer<OutputIteratorT>::value>;

  ThreadStore(itr, val, ModifierTag(), IsPointerTag());
}

#endif // _CCCL_DOXYGEN_INVOKED

CUB_NAMESPACE_END